diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 373e9618e542..d4f3d8b3835c 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -79,6 +79,7 @@ jobs: TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \ + TM_PYTHON_VERSIONS="cp311-cp311" \ ./build_tools/python_deploy/build_linux_packages.sh - name: Post issue comment on build failure diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 5d05f0e51e93..03c23e0335d2 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -63,6 +63,7 @@ jobs: uses: actions/checkout@v3 with: submodules: 'true' + fetch-depth: 0 - name: Fetch PyTorch commit hash if: ${{ matrix.os-arch != 'windows-x86_64' }} diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 9bd30e243a5b..d5ccc2fc48dd 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -36,6 +36,8 @@ jobs: uses: actions/checkout@v3 with: submodules: 'true' + fetch-depth: 0 + - uses: ./.github/actions/setup-build with: cache-suffix: 'release' diff --git a/README.md b/README.md index bc8a9748be51..e273cedea230 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ PyTorch is an open source machine learning framework that facilitates the seamle [MLIR](https://mlir.llvm.org) The MLIR project offers a novel approach for building extensible and reusable compiler architectures, which address the issue of software fragmentation, reduce the cost of developing domain-specific compilers, improve compilation for heterogeneous hardware, and promote compatibility between existing compilers. + [Torch-MLIR](https://github.com/llvm/torch-mlir) Several vendors have adopted MLIR as the middle layer in their systems, enabling them to map frameworks such as PyTorch, JAX, and TensorFlow into MLIR and subsequently lower them to their target hardware. We have observed half a dozen custom lowerings from PyTorch to MLIR, making it easier for hardware vendors to focus on their unique value, rather than needing to implement yet another PyTorch frontend for MLIR. The ultimate aim is to be similar to the current hardware vendors adding LLVM target support, rather than each one implementing Clang or a C++ frontend. diff --git a/docs/long_term_roadmap.md b/docs/long_term_roadmap.md index 62c3b6f94171..0f0940efc32d 100644 --- a/docs/long_term_roadmap.md +++ b/docs/long_term_roadmap.md @@ -246,3 +246,19 @@ for current LTC-based toolchains onto TorchDynamo. This migration will improve the end-user experience since TorchDynamo is more seamless, but it is a end-user-impacting migration nonetheless and we will want to phase it appropriately with the community. + +### End-to-end (E2E) testing + +Torch-MLIR currently maintains its own test suite with +[hundreds of end-to-end tests](https://github.com/llvm/torch-mlir/tree/main/python/torch_mlir_e2e_test/test_suite) +that verify the correctness and completeness of our op lowerings. +These tests are tedious to write, and also sometimes hit corners +of PyTorch's API that aren't usually reachable by user code. +PyTorch already has an [end-to-end op test suite](https://github.com/pytorch/pytorch/blob/ead51864622467acd6835b6da86a166c1a32aa55/torch/testing/_internal/common_methods_invocations.py#L1) +and we should just plug into it. Here is [an example](https://github.com/pytorch/pytorch/blob/ead51864622467acd6835b6da86a166c1a32aa55/test/test_proxy_tensor.py#L1573) of doing so. +Even better, it would be great if TorchDynamo/PyTorch 2.0 +directly provided a way to plug into this. + +Additionally, we can leverage the [`pytorch-jit-paritybench`](https://github.com/jansel/pytorch-jit-paritybench) +to verify our end-to-end correctness on real models. + diff --git a/e2e_testing/main.py b/e2e_testing/main.py index 770d32ca54b9..91ca0c85f95e 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -27,7 +27,14 @@ from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET +from .xfail_sets import ( + LINALG_XFAIL_SET, + STABLEHLO_PASS_SET, + TOSA_PASS_SET, + LTC_XFAIL_SET, + TORCHDYNAMO_XFAIL_SET, + TORCHDYNAMO_CRASHING_SET +) # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests @@ -77,26 +84,33 @@ def main(): if args.config == "linalg": config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = LINALG_XFAIL_SET - if args.config == "tosa": + crashing_set = set() + elif args.config == "tosa": config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET - if args.config == "stablehlo": + crashing_set = set() + elif args.config == "stablehlo": config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) xfail_set = all_test_unique_names - STABLEHLO_PASS_SET + crashing_set = set() elif args.config == "native_torch": config = NativeTorchTestConfig() - xfail_set = {} + xfail_set = set() + crashing_set = set() elif args.config == "torchscript": config = TorchScriptTestConfig() - xfail_set = {} + xfail_set = set() + crashing_set = set() elif args.config == "lazy_tensor_core": config = LazyTensorCoreTestConfig() xfail_set = LTC_XFAIL_SET + crashing_set = set() elif args.config == "torchdynamo": config = TorchDynamoTestConfig() xfail_set = TORCHDYNAMO_XFAIL_SET + crashing_set = TORCHDYNAMO_CRASHING_SET - do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []) + do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set) available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt] if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None: for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed: diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index bf21e8bae35f..552c5e4c107b 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -27,10 +27,6 @@ "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2D_basic", - # error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic - # https://github.com/llvm/torch-mlir/issues/1859 - "ConvolutionModule2DGroups_basic", - # RuntimeError: Index tensor must have the same number of dimensions as self tensor # RuntimeError: Failed running call_function aten.nll_loss_backward(... # https://github.com/pytorch/pytorch/issues/89630 @@ -65,6 +61,7 @@ "IndexPutImpl1DIntNonAccumulateModule_basic", "IndexPutImpl2DFloatAccumulateModule_basic", "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DIndexModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", @@ -93,9 +90,148 @@ "ElementwiseAddScalar_NumToTensorFloat_Module_basic", # ERROR: assert isinstance(e, FakeTensor) "RsubInt0d_NumToTensor_Module_basic", + + # ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::squeeze.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. + "PrimsSqueezeModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + + # ERROR: RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: prims::view_of.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa. + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + + # See https://github.com/llvm/torch-mlir/pull/2040 and corresponding upstream issue + # https://github.com/pytorch/pytorch/issues/99752. + # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(bool) [TensorVariable()] {} + 'TensorToBoolZeroRank_basic', + 'TensorToBool_basic', + + # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} + 'AtenSubFloatModule_basic', + 'BoolFloatFalseModule_basic', + 'BoolFloatTrueModule_basic', + 'CeilFloatModule_basic', + 'DivFloatModule_basic', + 'GeFloatIntModule_basic', + 'GeFloatModule_basic', + 'GtFloatIntModule_basic', + 'NeFloatIntModule_basic', + 'SubFloatModule_basic', + 'TensorToFloatZeroRank_basic', + 'TensorToFloat_basic', + + # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} + 'AddIntModule_basic', + 'AtenIntTensorCharDtypeModule_basic', + 'BoolIntFalseModule_basic', + 'BoolIntTrueModule_basic', + 'DivIntModule_basic', + 'EqIntModule_basic', + 'GeIntModule_basic', + 'GtIntModule_basic', + 'MulIntModule_basic', + 'NeIntModule_basic', + 'SqrtIntModule_basic', + 'SubIntModule_basic', + 'TensorToIntZeroRank_basic', + 'TensorToInt_basic', + 'UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic', + 'ViewCollapseDynamicWithAtenSizeIntModule_basic', + + # torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)} + 'SortIntListReverse_basic', + + # torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {} + 'SortIntList_basic', + + # torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default + 'AtenFloatScalarModule_basic', + 'AtenIntBoolOpModule_basic', + 'OneHotModule_basic', + 'QuantizedMLP_basic', + 'ScalarImplicitFloatModule_basic', + 'ScalarImplicitIntModule_basic', + + # torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default + 'BincountMinlengthModule_basic', + 'BincountModule_basic', + 'BincountStaticSizeModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool + 'BoolFloatConstantModule_basic', + 'BoolIntConstantModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__ + 'ContainsIntList_False', + 'ContainsIntList_True', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all + 'AllBoolFalseModule_basic', + 'AllBoolTrueModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any + 'AnyBoolFalseModule_basic', + 'AnyBoolTrueModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt + 'SqrtIntConstantModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int + 'AtenIntBoolOpConstFalseModule_basic', + 'AtenIntBoolOpConstTrueModule_basic', + 'IntFloatModule_basic', + 'PowIntFloatModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len + 'LenStrModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel + 'NumelModule_basic', + 'NumelZeroRankModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max + 'PrimMaxIntModule_basic', + + # torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min + 'PrimMinIntModule_basic', + + # empty graph + 'IsFloatingPointFloat_True', + 'IsFloatingPointInt_False', + 'TorchPrimLoopForLikeModule_basic', + 'TorchPrimLoopWhileLikeModule_basic', +} + +# See https://github.com/llvm/torch-mlir/issues/2050 +TORCHDYNAMO_CRASHING_SET = { + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", + "ExpandAsFloatModule_basic", + "ExpandAsIntModule_basic", + "ExpandModule_basic", + "MoveDimIntModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNDynamicModule_basic", + "NumpyTRankNStaticModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceStaticModule_basic", + "TModuleRank2_basic", + "ToCopyModule_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", } STABLEHLO_PASS_SET = { + "ConstantBoolParameterModule_basic", "MaskedFillScalarIntValueStaticModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", @@ -124,6 +260,7 @@ "BucketizeTensorStaticModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "DetachModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalNotOpModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", @@ -139,6 +276,8 @@ "ElementwiseClampMinModule_basic", "ElementwiseClampMaxModule_basic", "ElementwisePowModule_basic", + "ElementwisePowTensorStaticModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", "ElementwiseExpModule_basic", "ElementwiseFlattenBroadcastModule_basic", "ElementwiseLeakyReluModule_basic", @@ -222,6 +361,7 @@ "GatherModule_basic", "Gather2DInputModdule_basic", "GatherRandomIndexModule_basic", + "GatherNegativeDimModule_basic", "GeluBackwardModule_basic", "HardswishModule_basic", "HardswishRandomModule_basic", @@ -234,10 +374,13 @@ "IndexSelectTwoIdxModule_basic", "IndexSelectWholeDimensionModule_basic", "IndexSelectWholeTensorModule_basic", + "IndexSelectNegativeDimModule_basic", "LayerNormLastDimModule_basic", "LayerNormModule_basic", "LayerNormNormalizeOverAllDimsModule_basic", "LeakyReluBackwardStaticModule_basic", + "LinalgVectorNormModule_basic", + "LinalgVectorNormKeepDimModule_basic", "MatmulBroadcastBatchDim_basic", "MatmulSingleDynamicBatchDim_basic", "Matmul_3d", @@ -253,6 +396,7 @@ "Mv_basic", "NativeLayerNormModule4D_basic", "NativeLayerNormModule_basic", + "OneHotModule_basic", "PrimsConvertElementTypeModule_basic", "ReduceFrobeniusNormKeepDimModule_basic", "ReduceSumDimIntListElementTypeBoolModule_basic", @@ -264,6 +408,15 @@ "ReduceSumDimIntListKeepDimIntModule_basic", "ReduceSumDtypeFloatModule_basic", "ReduceSumDtypeIntModule_basic", + "ReduceL1NormModule_basic", + "ReduceL1NormWithDTypeModule_basic", + "ReduceL2NormModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimModule_basic", + "ReduceLN3NormModule_basic", + "NormScalarOptDimKeepDimModule_basic", + "NormScalarOptDimModule_basic", + "NormalizeModule_basic", "SelectIntModule_basic", "SelectIntNegativeDimAndIndexStaticModule_basic", "SliceSingleIdxModule_basic", @@ -323,6 +476,8 @@ "RsubIntModule_basic", "RsubIntModule_noalpha_basic", "RsubInt0d_NumToTensor_Module_basic", + "SelectScattertModule_basic", + "SelectScattertStaticModule_basic", "SliceStaticModule_basic", "SliceModule_basic", "SliceNegIdxModule_basic", @@ -331,6 +486,12 @@ "SliceStartEqEndModule_basic", "SliceSizeTwoStepModule_basic", "SliceWholeTensorModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", "SqueezeDimModule_static", "SqueezeDimModule_identity", "SqueezeModule_broadcast", @@ -378,6 +539,7 @@ "NewOnesModuleFloat2D_basic", "NewOnesModuleFloat3D_basic", "NewOnesModuleFalsePinMemory_basic", + "NewZerosStaticModuleLayoutStrided_basic", "DropoutEvalIntModule_basic", "DropoutEvalFloatModule_basic", "ContiguousModule_basic", @@ -457,11 +619,21 @@ "AtenRoundIntModule_basic", "TestF16Return_basic", "_LogSoftmaxModuleStable_basic", + "PrimsSqueezeModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "MoveDimIntModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", } # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "MaxPool2dEmptyStrideStaticModule_basic", + "ConstantBoolParameterModule_basic", "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneModule_basic", @@ -565,6 +737,8 @@ "ElementwiseFlattenBroadcastModule_basic", "SquareModule_basic", "MaxPool2dStaticModule_basic", + "MaxPool2dStaticCeilModeTrueModule_basic", + "ResNet18StaticModule_basic", "ReduceAmaxKeepDim_basic", "NativeLayerNormModule4D_basic", "LayerNormNormalizeOverAllDimsModule_basic", @@ -652,6 +826,7 @@ "MaskedFillTensorIntValueStaticModule_basic", "ElementwiseAddScalarInt64Module_basic", "TensorLiteralModule_basic", + "NewZerosStaticModuleLayoutStrided_basic", "TensorOpaqueLiteralModule_basic", "TypePromotionDifferentCategoryModule_basic", "TypePromotionSameCategoryDifferentWidthModule_basic", @@ -660,6 +835,12 @@ "GatherStaticModule_basic", "IndexTensorStaticModule_basic", "IndexTensorMultiIndexStaticModule_basic", + "ElementwiseWhereScalarModule_basic", + "FullLikeModuleFloat3DStatic_basic", + "FullModuleDefaultDtype_basic", + "FullModuleFloat3D_basic", + "MaskedFillScalarDefaultModule_basic", + "NumToTensorFloatModule_basic", "LiftFreshCopyModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "ReduceSumDimIntListFloatModule_basic", @@ -698,11 +879,32 @@ "HardsigmoidRandomModule_basic", "HardswishModule_basic", "HardswishRandomModule_basic", + "BucketizeTensorStaticFloatModule_basic", + "BucketizeTensorStaticModule_basic", + "ElementwiseLeFloatTensorModule_basic", + "ElementwiseLeIntTensorModule_basic", "FullLikeModuleInt2DStatic_basic", "FullModuleInt3D_basic", "FullModuleFloat2D_basic", + "ElementwiseAbsModule_basic", "RepeatModule_basic", - "ResNet18StaticModule_basic", + "ConstantPad2dStaticModule_basic", + "ConstantPadNdModule_basic", + "ConstantPadNdPartialStaticModule_basic", + "ConstantPadNdStaticModule_basic", + "PadModule_basic", + "PadWithNoneValModule_basic", + "ElementwiseRemainderScalarModule_Float_basic", + "ElementwiseRemainderScalarModule_Int_Float_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "MoveDimIntModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "MoveDimIntModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "DetachModule_basic", } LTC_XFAIL_SET = { @@ -769,8 +971,10 @@ "IndexPutImpl1DIntNonAccumulateModule_basic", "IndexPutImpl2DFloatAccumulateModule_basic", "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DIndexModule_basic", "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", "IndexTensorModule3dInput_basic", "IndexTensorModule_basic", "IndexTensorStaticModule_basic", @@ -838,24 +1042,17 @@ "VarMeanCorrectionModule_basic", "VarMeanCorrectionNoneModule_basic", "PrimsConvertElementTypeModule_basic", - "CopyModule_basic", - "CopyWithDifferentDTypesAndSizesModule_basic", - "CopyWithDifferentDTypesModule_basic", - "CopyWithDifferentSizesModule_basic", "ElementwisePreluModule_basic", "VarMeanBiasedModule_basic", "VarMeanUnbiasedModule_basic", "RandnLikeModule_basic", "RandnLikeDtypeModule_basic", - "NewEmptyStridedModuleDefaultDtype_basic", "BernoulliFloatModule_basic", "BernoulliModule_basic", "BernoulliPModule_basic", "DropoutTrainModule_basic", "StdCorrectionKeepDimModule_basic", "StdCorrectionNoneModule_basic", - "SliceCopy_Module_basic", - "SliceCopyNegative_Module_basic", "VarBiasedModule_basic", "VarCorrectionAllDimReduceModule_basic", "VarCorrectionEmptyDimModule_basic", @@ -874,5 +1071,12 @@ "VarDimSingleDimModule_basic", "VarDimUnbiasedModule_basic", "VarUnbiasedModule_basic", - "AtenFloatScalarModule_basic" + "AtenFloatScalarModule_basic", + "PrimsSqueezeModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "OneHotModule_basic", + "VarMeanDimModule_basic", + "VarMeanDimBiasedModule_basic", } diff --git a/externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt b/externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt index 2de2d4eba67e..c8f747af79c8 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt +++ b/externals/llvm-external-projects/torch-mlir-dialects/CMakeLists.txt @@ -47,9 +47,9 @@ function(torch_mlir_dialects_target_includes target) # target, when present, is just used for compilation and does not # contribute to the interface properties. # TODO: Normalize this upstream. - target_include_directories(${target} PUBLIC ${_dirs}) + target_include_directories(${target} PUBLIC "${_dirs}") if(TARGET obj.${target}) - target_include_directories(obj.${target} PRIVATE ${_dirs}) + target_include_directories(obj.${target} PRIVATE "${_dirs}") endif() endfunction() diff --git a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td index cc048e47b776..f692d95a2152 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td +++ b/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td @@ -192,6 +192,58 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", }]; } +def TMTensor_SortOp : TMTensor_Op<"sort", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Sort operator"; + let description = [{ + Based on XLA operation semantics, sorts the given `operands` at the given + `dimension` with the given `comparator`. + + See https://www.tensorflow.org/xla/operation_semantics#sort. + }]; + + let arguments = (ins Variadic:$inputs, + Variadic:$outputs, + I64Attr:$dimension + ); + let results = (outs Variadic:$results); + let regions = (region AnyRegion:$region); + let assemblyFormat = [{ + attr-dict + `dimension` `(` $dimension `)` + (`ins` `(` $inputs^ `:` type($inputs) `)`)? + `outs` `(` $outputs `:` type($outputs) `)` + $region (`->` type($results)^)? + }]; + let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{ + Value operand(int index) { + return getOutputs()[index]; + } + ShapedType getOperandType(int index) { + return operand(index).getType().cast(); + } + int64_t getOperandRank() { + return getOperandType(0).getRank(); + } + ArrayRef getOperandShape() { + return getOperandType(0).getShape(); + } + + // Method to implement for specifying output range for + // DestinationStyleOpInterface + std::pair getDpsInitsPositionRange() { + std::pair outputsIndexAndLength = + getODSOperandIndexAndLength(1); + return std::make_pair( + outputsIndexAndLength.first, + outputsIndexAndLength.first + outputsIndexAndLength.second); + } + }]; +} + //===----------------------------------------------------------------------===// // Pure ops //===----------------------------------------------------------------------===// diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 6ce9b502f0f4..99eebd7c580c 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -476,6 +476,172 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, return success(); } +//===----------------------------------------------------------------------===// +// SortOp +//===----------------------------------------------------------------------===// + +LogicalResult SortOp::verify() { + Operation *op = getOperation(); + if (getNumInputs()) { + return op->emitOpError("does not expect to take any inputs"); + } + if (getNumOutputs() == 0) { + return op->emitOpError("expected at least one `outs` operand"); + } + + Block &block = getRegion().front(); + size_t numOutputs = getNumOutputs(); + if (block.getNumArguments() != 2 * numOutputs) { + return op->emitOpError("region block should have ") + << 2 * numOutputs << " arguments"; + } + + int64_t rank = getOperandRank(); + int sortDim = getDimension(); + if (sortDim < 0 || sortDim >= rank) { + return op->emitOpError("dimension must be within (0, ") << rank << "]"; + } + + ArrayRef shape = getOperandShape(); + for (auto indexedOperand : llvm::enumerate(getOutputs())) { + int index = indexedOperand.index(); + auto operandType = getOperandType(index); + if (operandType.getRank() != rank) { + return op->emitOpError("expected operand ") + << index << " to be rank " << rank << ", same as other operands"; + } + if (operandType.getShape() != shape) { + return op->emitOpError("expected operand ") + << index << " to have same shape as other operands"; + } + Type elemType = operandType.getElementType(); + for (int i : {2 * index, 2 * index + 1}) { + Type argType = block.getArgument(i).getType(); + if (argType != elemType) { + return op->emitOpError("region block argument #") + << i << " should be of type " << elemType << " but got " + << argType; + } + } + } + + auto yieldOp = cast(block.getTerminator()); + if (yieldOp.getNumOperands() != 1) { + return op->emitOpError("should yield exactly one operand"); + } + auto ty = yieldOp.getOperand(0).getType().dyn_cast(); + if (!ty || ty.getWidth() != 1) { + return op->emitOpError("should yield i1 type"); + } + + return success(); +} + +SmallVector SortOp::getLoopIteratorTypes() { + // All loops except the dimension to sort along are parallel. + SmallVector iteratorTypes(getOperandRank(), + utils::IteratorType::parallel); + iteratorTypes[getDimension()] = utils::IteratorType::reduction; + return iteratorTypes; +} + +SmallVector SortOp::getIterationDomain(OpBuilder &builder) { + int64_t operandRank = getOperandRank(); + SmallVector loopBounds(operandRank); + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value source = operand(0); + for (auto dim : llvm::seq(0, operandRank)) { + loopBounds[dim].offset = zero; + loopBounds[dim].size = getDimValue(builder, loc, source, dim); + loopBounds[dim].stride = one; + } + return loopBounds; +} + +LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc, + ValueRange ivs) { + auto sortDim = getDimension(); + SmallVector indices, sortBlkArgs; + indices.append(ivs.begin(), ivs.end()); + // Bubble sort innermost loop. + Value zero = b.create(loc, 0); + Value one = b.create(loc, 1); + Value ub; + if (getOperandType(0).isDynamicDim(sortDim)) { + ub = b.create(loc, operand(0), sortDim); + } else { + ub = b.create( + loc, getOperandType(0).getDimSize(sortDim)); + } + ub = b.create(loc, ub, one); + auto scfFor = b.create( + loc, zero, ub, one, ValueRange{}, + [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) { + SmallVector indices(ivs); + Value ivPlusOne = b.create(loc, iv, one); + for (auto output : getOutputOperands()) { + indices[sortDim] = iv; + sortBlkArgs.push_back( + b.create(loc, output->get(), indices)); + indices[sortDim] = ivPlusOne; + sortBlkArgs.push_back( + b.create(loc, output->get(), indices)); + } + }); + + auto &srcBlock = getRegion().front(); + Region ®ion = scfFor.getRegion(); + IRMapping bvm; + { + OpBuilder::InsertionGuard guard(b); + auto &block = region.front(); + b.setInsertionPointToEnd(&block); + for (auto it : llvm::zip(srcBlock.getArguments(), sortBlkArgs)) { + bvm.map(std::get<0>(it), std::get<1>(it)); + } + for (auto &blockOp : srcBlock.without_terminator()) { + b.clone(blockOp, bvm); + } + } + Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)); + + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToEnd(®ion.front()); + b.create( + loc, cond, + [&](OpBuilder &b, Location loc) { + // Do not swap the pairs if true. + b.create(loc); + }, + [&](OpBuilder &b, Location loc) { + // Swap the pairs if false. + SmallVector indices(ivs.begin(), ivs.end()); + Value ivPlusOne = + b.create(loc, scfFor.getInductionVar(), one); + for (int i = 0, e = getNumOutputs(); i < e; ++i) { + Value v1 = sortBlkArgs[i * 2]; + Value v2 = sortBlkArgs[i * 2 + 1]; + indices[sortDim] = scfFor.getInductionVar(); + b.create(loc, v2, getOutputOperand(i)->get(), + indices); + indices[sortDim] = ivPlusOne; + b.create(loc, v1, getOutputOperand(i)->get(), + indices); + } + b.create(loc); + }); + b.create(loc); + return success(); +} + +bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) { + // All operands of SortOp will be sorted. So, we'll end up loading/storing + // from them - hence setting this utility to always return `true`. + return true; +} + #define DEFINE_OP_GET_EFFECTS(OP_NAME) \ void OP_NAME::getEffects( \ SmallVectorImpl> \ @@ -488,6 +654,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, DEFINE_OP_GET_EFFECTS(ScanOp) DEFINE_OP_GET_EFFECTS(ScatterOp) +DEFINE_OP_GET_EFFECTS(SortOp) namespace { /// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any diff --git a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index a79c0e09f300..36d061f3237e 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -31,7 +31,7 @@ using namespace ::mlir::torch::TMTensor; static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { auto memrefType = memref.getType().cast(); auto alloc = b.create( - loc, memrefType, linalg::getDynOperands(loc, memref, b)); + loc, memrefType, linalg::createDynamicDimensions(b, loc, memref)); b.create(loc, memref, alloc); return alloc; } @@ -73,7 +73,8 @@ allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp, } resultBuffers.push_back(b.create( - loc, memrefType, linalg::getDynOperands(loc, resultTensor, b))); + loc, memrefType, + linalg::createDynamicDimensions(b, loc, resultTensor))); } return success(); } @@ -86,7 +87,7 @@ static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter, ValueRange outputs) { SmallVector newOperands = inputs; newOperands.append(outputs.begin(), outputs.end()); - return tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands); + return cast(tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands)); } /// Generic conversion pattern that matches any TMTensorOp. This avoids template diff --git a/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/torch-mlir-dialects-opt.cpp b/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/torch-mlir-dialects-opt.cpp index 5e65bb28df88..058dfc82db35 100644 --- a/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/torch-mlir-dialects-opt.cpp +++ b/externals/llvm-external-projects/torch-mlir-dialects/tools/torch-mlir-dialects-opt/torch-mlir-dialects-opt.cpp @@ -43,7 +43,6 @@ int main(int argc, char **argv) { mlir::func::FuncDialect, mlir::memref::MemRefDialect, mlir::scf::SCFDialect, mlir::tensor::TensorDialect>(); - return mlir::asMainReturnCode( - mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry, - /*preloadDialectsInContext=*/false)); + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "MLIR modular optimizer driver\n", registry)); } diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index b5f30bfbe724..3a130f472b3b 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -125,7 +125,7 @@ def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToTMTensorPass()"; } -def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "func::FuncOp"> { +def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprogram", "ModuleOp"> { let summary = "Convert recognized TorchConversion ops to MLProgram ops"; let description = [{ Convert TorchConversion ops to mlprogram ops. diff --git a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h b/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h index 79d962492bcf..6d14ec92737c 100644 --- a/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h +++ b/include/torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h @@ -10,12 +10,12 @@ #ifndef TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H #define TORCHMLIR_CONVERSION_TORCHCONVERSIONTOMLPROGRAM_TORCHCONVERSIONTOMLPROGRAM_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" namespace mlir { namespace torch { -std::unique_ptr> +std::unique_ptr> createConvertTorchConversionToMLProgramPass(); } } // namespace mlir diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h similarity index 100% rename from lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h rename to include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index adafe173cc62..485160b7e830 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -89,6 +89,17 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, std::optional srcOriginalDtype = std::nullopt); +Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, + Value torchOptionalInt, Value builtinInt, + Value defaultValue, Value dimSize); + +// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If +// yes, then computes the final broadcast shape. +void computeBroadcastShape(ConversionPatternRewriter &rewriter, Location loc, + Value inputA, Value inputB, + SmallVector &resultShape, + SmallVector &resultShapeValue); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 2a43ec66ce9d..357f95fd2179 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -700,6 +700,51 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ }]; } +def Torch_AtenAtanOp : Torch_Op<"aten.atan", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::atan : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtanOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtanOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAtan_Op : Torch_Op<"aten.atan_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::atan_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtan_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtan_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [ AllowsTypeRefinement, HasValueSemantics, @@ -3464,6 +3509,30 @@ def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [ }]; } +def Torch_AtenPowScalarOp : Torch_Op<"aten.pow.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$self, + AnyTorchTensorType:$exponent + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPowScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenPowScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [ AllowsTypeRefinement, HasValueSemantics, @@ -3742,6 +3811,34 @@ def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [ }]; } +def Torch_AtenRandintOp : Torch_Op<"aten.randint", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + Torch_IntType:$high, + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRandintOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenRandintOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenBernoulliTensorOp : Torch_Op<"aten.bernoulli.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -4343,37 +4440,6 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [ }]; } -def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideable", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchListOfTorchIntType:$stride, - AnyTorchListOfTorchIntType:$padding, - AnyTorchListOfTorchIntType:$dilation, - Torch_BoolType:$transposed, - AnyTorchListOfTorchIntType:$output_padding, - Torch_IntType:$groups - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenConvolutionOverrideableOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 9, 1); - } - void AtenConvolutionOverrideableOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 9, 1); - } - }]; -} - def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [ AllowsTypeRefinement, HasValueSemantics, @@ -4503,40 +4569,6 @@ def Torch_AtenConvolutionBackwardOp : Torch_Op<"aten.convolution_backward", [ }]; } -def Torch_AtenConvolutionBackwardOverrideableOp : Torch_Op<"aten.convolution_backward_overrideable", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$grad_output, - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchListOfTorchIntType:$stride, - AnyTorchListOfTorchIntType:$padding, - AnyTorchListOfTorchIntType:$dilation, - Torch_BoolType:$transposed, - AnyTorchListOfTorchIntType:$output_padding, - Torch_IntType:$groups, - AnyTorchListOfTorchBoolType:$output_mask - ); - let results = (outs - AnyTorchTensorType:$grad_input, - AnyTorchTensorType:$grad_weight, - AnyTorchTensorType:$grad_bias - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenConvolutionBackwardOverrideableOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 10, 3); - } - void AtenConvolutionBackwardOverrideableOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 10, 3); - } - }]; -} - def Torch_AtenFlipOp : Torch_Op<"aten.flip", [ AllowsTypeRefinement, HasValueSemantics, @@ -4997,6 +5029,30 @@ def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [ }]; } +def Torch_AtenMovedimIntOp : Torch_Op<"aten.movedim.int", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::movedim.int : (Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$source, + Torch_IntType:$destination + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMovedimIntOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMovedimIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenBmmOp : Torch_Op<"aten.bmm", [ AllowsTypeRefinement, HasValueSemantics, @@ -5398,6 +5454,90 @@ def Torch_AtenVarMeanOp : Torch_Op<"aten.var_mean", [ }]; } +def Torch_AtenVarMeanDimOp : Torch_Op<"aten.var_mean.dim", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::var_mean.dim : (Tensor, int[]?, bool, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$unbiased, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenVarMeanDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 2); + } + void AtenVarMeanDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 2); + } + }]; +} + +def Torch_AtenNllLoss2dForwardOp : Torch_Op<"aten.nll_loss2d_forward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction, + Torch_IntType:$ignore_index + ); + let results = (outs + AnyTorchTensorType:$output, + AnyTorchTensorType:$total_weight + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNllLoss2dForwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 2); + } + void AtenNllLoss2dForwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 2); + } + }]; +} + +def Torch_AtenNllLoss2dBackwardOp : Torch_Op<"aten.nll_loss2d_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction, + Torch_IntType:$ignore_index, + AnyTorchTensorType:$total_weight + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNllLoss2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenNllLoss2dBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [ AllowsTypeRefinement, HasValueSemantics, @@ -5584,6 +5724,34 @@ def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_ba }]; } +def Torch_AtenCrossEntropyLossOp : Torch_Op<"aten.cross_entropy_loss", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction, + Torch_IntType:$ignore_index, + Torch_FloatType:$label_smoothing + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCrossEntropyLossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenCrossEntropyLossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ AllowsTypeRefinement, HasValueSemantics, @@ -6238,6 +6406,30 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ }]; } +def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::one_hot : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$num_classes + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenOneHotOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenOneHotOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -6432,6 +6624,7 @@ def Torch_AtenDetachOp : Torch_Op<"aten.detach", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [ @@ -7286,6 +7479,7 @@ def Torch_AtenToDtypeLayoutOp : Torch_Op<"aten.to.dtype_layout", [ } }]; let hasFolder = 1; + let hasCanonicalizer = 1; } def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [ @@ -8536,6 +8730,58 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [ }]; } +def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + AnyTorchTensorType:$src + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenScatterSrcOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + AnyTorchScalarType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenScatterValueOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ AllowsTypeRefinement, HasValueSemantics, @@ -9096,6 +9342,32 @@ def Torch_AtenSortIntOp : Torch_Op<"aten.sort.int", [ let hasCanonicalizer = 1; } +def Torch_AtenSortOp : Torch_Op<"aten.sort", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$descending + ); + let results = (outs + AnyTorchTensorType:$values, + AnyTorchTensorType:$indices + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSortOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); + } + void AtenSortOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); + } + }]; +} + def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [ AllowsTypeRefinement, HasValueSemantics, @@ -10914,6 +11186,7 @@ def Torch_PrimDeviceOp : Torch_Op<"prim.device", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasCanonicalizer = 1; } def Torch_PrimDtypeOp : Torch_Op<"prim.dtype", [ @@ -11269,6 +11542,52 @@ def Torch_PrimsSqrtOp : Torch_Op<"prims.sqrt", [ }]; } +def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `prims::squeeze : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$a, + AnyTorchListOfTorchIntType:$dimensions + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult PrimsSqueezeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void PrimsSqueezeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `prims::view_of : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$a + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult PrimsViewOfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void PrimsViewOfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [ HasValueSemantics, AllowsTypeRefinement, diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 4cf27639ab90..8e817374bd5d 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -56,7 +56,12 @@ struct TorchLoweringPipelineOptions // to check for a specific set of legal ops to stop its iteration. ListOption backendLegalOps{ *this, "backend-legal-ops", - llvm::cl::desc("List of ops to be considered legal for the backend.")}; + llvm::cl::desc("List of ops to be considered legal for the backend, such " + "as 'aten.foo'.")}; + + Option extraLibrary{ + *this, "extra-library", + llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")}; }; /// Creates a pipeline that lowers the object graph IR that is produced by @@ -78,10 +83,12 @@ void createTorchSimplificationPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); /// Creates a pipeline that refines shapes of tensor operations in the program. -void createTorchShapeRefinementPipeline(OpPassManager &pm); +void createTorchShapeRefinementPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options); /// Creates a pipeline that refines dtype of tensor operations in the program. -void createTorchDtypeRefinementPipeline(OpPassManager &pm); +void createTorchDtypeRefinementPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options); std::unique_ptr> createAdjustCallingConventionsPass(); @@ -89,7 +96,8 @@ std::unique_ptr> createRefineTypesPass(); std::unique_ptr> createInlineGlobalSlotsPass(); -std::unique_ptr> createReduceOpVariantsPass(); +std::unique_ptr> +createReduceOpVariantsPass(StringRef extraLibrary); std::unique_ptr> createMaximizeValueSemanticsPass(); @@ -98,16 +106,16 @@ std::unique_ptr> createRefinePublicReturnPass(); std::unique_ptr> createDecomposeComplexOpsPass(ArrayRef legalOps); -std::unique_ptr> createRecomposeComplexOps(); +std::unique_ptr> createRecomposeComplexOpsPass(); -std::unique_ptr> createPreprocessShapeLibraryPass(); - -std::unique_ptr> createReifyShapeCalculationsPass(); +std::unique_ptr> +createReifyShapeCalculationsPass(StringRef extraLibrary); std::unique_ptr> createSimplifyShapeCalculationsPass(); -std::unique_ptr> createReifyDtypeCalculationsPass(); +std::unique_ptr> +createReifyDtypeCalculationsPass(StringRef extraLibrary); std::unique_ptr> createSimplifyDtypeCalculationsPass(); @@ -120,13 +128,16 @@ createEraseModuleInitializerPass(); std::unique_ptr> createLowerToBackendContractPass(int maxIterations, bool decompose, - ArrayRef backendLegalOps); + ArrayRef backendLegalOps, + StringRef extraLibrary); std::unique_ptr> createVerifyBackendContractNoDecompositionsPass(); StringRef getAbstractInterpLibrary(); +static const char kTorchOpPrefix[] = R"(torch.)"; + } // namespace Torch /// Registers all Torch transformation passes. diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 1ee87b36ec27..8369d1d3d185 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -151,7 +151,13 @@ def InlineGlobalSlots : Pass<"torch-inline-global-slots", "ModuleOp"> { def ReduceOpVariants : Pass<"torch-reduce-op-variants", "func::FuncOp"> { let summary = "Reduces variants of ops to a smaller set of ops."; - let constructor = "mlir::torch::Torch::createReduceOpVariantsPass()"; + let constructor = [{ + mlir::torch::Torch::createReduceOpVariantsPass(/*extraLibrary=*/"") + }]; + let options = [ + Option<"extraLibrary", "extra-library", "std::string", /*default=*/"", + "MLIR module for verifying custom op value semantics">, + ]; let description = [{ Replaces ops with other ops to reduce the number of variants that need to be handled elsewhere in the code. @@ -238,9 +244,38 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> { }]; } +def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> { + let summary = "Recompose torch operations that have been decomposed by TorchScript"; + let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()"; + let description = [{ + There are certain ops that TorchScript will split into multiple ops that + prevent optimizations in Torch-MLIR from taking place. In this pass such + sequences of ops are identified and combined into a higher level op, + preserving the original behavior, while allowing new optimizations to happen. + + An example is the handling of the indexing operation in PyTorch. The following + + ``` + input_tensor[1:2, :] = 7 + ``` + + will get split into a series of `slice` ops to get the sub-tensor, then an + in-place copy to overwrite the sub-tensor with the value 7. This type of + pattern prevents the `MaximizeValueSemantics` pass from succeeding. So, + using `RecomposeComplexOps`, the series of slices + copy is identified + and turned into a single `index_put` operation. + }]; +} + def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> { let summary = "Reify shape calculations."; - let constructor = "mlir::torch::Torch::createReifyShapeCalculationsPass()"; + let constructor = [{ + mlir::torch::Torch::createReifyShapeCalculationsPass(/*extraLibrary=*/"") + }]; + let options = [ + Option<"extraLibrary", "extra-library", "std::string", /*default=*/"", + "MLIR module for splicing into the shape library">, + ]; let description = [{ }]; } @@ -255,7 +290,13 @@ def SimplifyShapeCalculations : Pass<"torch-simplify-shape-calculations", "func: def ReifyDtypeCalculations : Pass<"torch-reify-dtype-calculations", "ModuleOp"> { let summary = "Reify dtype calculations."; - let constructor = "mlir::torch::Torch::createReifyDtypeCalculationsPass()"; + let constructor = [{ + mlir::torch::Torch::createReifyDtypeCalculationsPass(/*extraLibrary=*/"") + }]; + let options = [ + Option<"extraLibrary", "extra-library", "std::string", /*default=*/"", + "MLIR module for splicing into the dtype library">, + ]; let description = [{ }]; } @@ -291,7 +332,7 @@ def LowerToBackendContract let summary = "Perform simplifications until the backend contract is satisfied."; let constructor = [{ mlir::torch::Torch::createLowerToBackendContractPass( - /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}) + /*maxIterations=*/10, /*decompose=*/true, /*backendLegalOps=*/{}, /*extraLibrary=*/"") }]; let description = [{ This pass performs the bulk of the lowering of the program's computations @@ -335,7 +376,9 @@ def LowerToBackendContract Option<"decompose", "decompose", "bool", /*default=*/"true", "Decompose ops.">, ListOption<"backendLegalOps", "backend-legal-ops", "std::string", - "List of ops to be considered legal for the backend."> + "List of ops to be considered legal for the backend, such as 'aten.foo'.">, + Option<"extraLibrary", "extra-library", "std::string", /*default=*/"", + "MLIR module for splicing into the abstract interpretation library">, ]; // TODO: Debug why this is needed, even though the input program has func.func diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 4e3f3ceccf59..37aaed9cd704 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -75,6 +75,17 @@ int64_t getNumberOfElements(RankedTensorType inputType); SmallVector makeShapeLLVMCompatible(ArrayRef shape); SmallVector makeShapeTorchCompatible(ArrayRef shape); + +// Helper function to squeeze the input tensor at given dim. +// Return the squeezed tensor or failure. +FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Location loc, int64_t dim, Value input); + +// Helper function to unsqueeze the input tensor at given dim. +// Return the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, Value dim); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index 2e98b37d413f..aa832141f1de 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -11,6 +11,7 @@ #define TORCHMLIR_CONVERSION_PASSDETAIL_H #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt b/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt index f819ad018909..b89ffbb43f25 100644 --- a/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt +++ b/lib/Conversion/TorchConversionToMLProgram/CMakeLists.txt @@ -12,9 +12,10 @@ add_mlir_conversion_library(TorchMLIRTorchConversionToMLProgram LINK_LIBS PUBLIC MLIRIR - MLIRPass MLIRLinalgDialect + MLIRMLProgramDialect MLIRMathDialect + MLIRPass TorchMLIRTorchDialect ) diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index 839bae3646b7..eab81c2bec18 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -28,10 +28,22 @@ using namespace mlir::torch::TorchConversion; static constexpr StringRef getSeedGobalVarName() { return "global_seed"; } // Declare a tensor global variable for the seed. -static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) { - b.setInsertionPointToStart(module.getBody()); +static LogicalResult getOrCreateGlobalVariableForSeed(OpBuilder &b, + ModuleOp module) { + auto globalSeedSymbol = + SymbolTable::lookupSymbolIn(module, getSeedGobalVarName()); + Type elemTy = b.getI64Type(); auto tensorType = RankedTensorType::get({}, elemTy); + + if (globalSeedSymbol) { + auto globalSeed = dyn_cast(globalSeedSymbol); + if (!globalSeed || globalSeed.getType() != tensorType) + return module.emitError("Unexpected type for global seed."); + return success(); + } + + b.setInsertionPointToStart(module.getBody()); b.create( UnknownLoc::get(b.getContext()), /*sym_name=*/getSeedGobalVarName(), @@ -39,6 +51,8 @@ static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) { /*is_mutable=*/true, /*value=*/DenseIntElementsAttr::get(tensorType, {APInt(64, 0)}), /*sym_visibility=*/b.getStringAttr("private")); + + return success(); } namespace { @@ -104,22 +118,27 @@ class ConvertTorchConversionToMLProgram typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - auto module = getOperation()->getParentOfType(); + auto module = getOperation(); OpBuilder b(module.getBodyRegion()); - createGlobalVariableForSeed(b, module); + if (failed(getOrCreateGlobalVariableForSeed(b, module))) + signalPassFailure(); RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + getOperation()->walk( + [this, &target, &frozenPatterns](func::FuncOp function) { + if (failed(applyPartialConversion(function, target, frozenPatterns))) + return signalPassFailure(); + }); } }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::createConvertTorchConversionToMLProgramPass() { return std::make_unique(); } diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 1f921dcaadee..968b8d882280 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -296,7 +296,7 @@ class ConvertAtenAddOp : public OpConversionPattern { } // namespace namespace { -template +template class ConvertAtenAnyOrAllBoolOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -305,38 +305,37 @@ class ConvertAtenAnyOrAllBoolOp : public OpConversionPattern { LogicalResult matchAndRewrite(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - + Location loc = op.getLoc(); + Value result; SmallVector inputListTorchBool; if (!getListConstructElements(op.getSelf(), inputListTorchBool)) { return rewriter.notifyMatchFailure( - op, "Unimplemented input list not constructed from ListConstruct"); - } - SmallVector inputListBool; - for (Value v : inputListTorchBool) { - bool cst; - if (!matchPattern(v, m_TorchConstantBool(&cst))) - return rewriter.notifyMatchFailure( - op, "only support constant bool input list elements"); - inputListBool.push_back(cst); + op, "unimplemented: input list not constructed from ListConstruct"); } - bool result = reductionFunction(inputListBool); - - rewriter.replaceOpWithNewOp( - op, rewriter.getBoolAttr(result)); + SmallVector inputList = getTypeConvertedValues( + rewriter, loc, this->getTypeConverter(), inputListTorchBool); + result = inputList[0]; + for (unsigned i = 1; i < inputList.size(); i++) + result = rewriter.create(loc, result, inputList[i]); + rewriter.replaceOp(op, result); return success(); } }; -class ConvertAtenAnyOp : public ConvertAtenAnyOrAllBoolOp { - using ConvertAtenAnyOrAllBoolOp::ConvertAtenAnyOrAllBoolOp; +class ConvertAtenAnyOp + : public ConvertAtenAnyOrAllBoolOp { + using ConvertAtenAnyOrAllBoolOp::ConvertAtenAnyOrAllBoolOp; bool reductionFunction(ArrayRef inputArray) const override { return llvm::any_of(inputArray, [](bool inputListElem) { return inputListElem; }); } }; -class ConvertAtenAllOp : public ConvertAtenAnyOrAllBoolOp { - using ConvertAtenAnyOrAllBoolOp::ConvertAtenAnyOrAllBoolOp; +class ConvertAtenAllOp + : public ConvertAtenAnyOrAllBoolOp { + using ConvertAtenAnyOrAllBoolOp::ConvertAtenAnyOrAllBoolOp; bool reductionFunction(ArrayRef inputArray) const override { return llvm::all_of(inputArray, [](bool inputListElem) { return inputListElem; }); @@ -468,6 +467,9 @@ class ConvertTorchToArith : public ConvertTorchToArithBase target.addIllegalOp(); patterns.add>( typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns .add>( diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 293649de5451..5062f84be793 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -33,31 +33,6 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -static Value toPositiveValidDim(ConversionPatternRewriter &rewriter, - Location loc, Value torchOptionalInt, - Value builtinInt, Value defaultValue, - Value dimSize) { - if (torchOptionalInt.getType().isa()) - return defaultValue; - auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); - Value positiveDim = - toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt); - // positveDim < 0 ? 0 : positiveDim - Value cst0 = rewriter.create( - loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); - Value predDimSltZero = rewriter.create( - loc, arith::CmpIPredicate::slt, positiveDim, cst0); - Value atLeastZero = - rewriter.create(loc, predDimSltZero, cst0, positiveDim); - // atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero - Value sgtDimSize = rewriter.create( - loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt); - Value boundedByDimSize = rewriter.create( - loc, sgtDimSize, dimSizeAsInt, atLeastZero); - - return castIntToIndex(rewriter, loc, boundedByDimSize); -} - template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, @@ -145,6 +120,13 @@ class ConvertAtenFlattenUsingIntsOp return rewriter.notifyMatchFailure(op, "end_dim must be constant"); auto type = adaptor.getSelf().getType().cast(); auto inputRank = type.getRank(); + if (inputRank == 1) { + // If input rank is equal to 1, then there's no scope for flattening the + // input tensor. + rewriter.replaceOp(op, adaptor.getSelf()); + return success(); + } + auto resultType = getTypeConverter()->convertType(op.getType()).cast(); if (startDim < 0) @@ -874,10 +856,9 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "dim must be constant"); auto inputRank = adaptor.getSelf().getType().cast().getRank(); - if (dim < 0) - dim += inputRank + 1; - if (!(0 <= dim && dim <= inputRank)) - return rewriter.notifyMatchFailure(op, "statically invalid"); + dim = toPositiveDim(dim, inputRank + 1); + if (!isValidDim(dim, inputRank + 1)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector reassociationMap(inputRank); // From the perspective of the reassociation map, the situation of @@ -1101,11 +1082,6 @@ class ConvertAtenCatOp : public OpConversionPattern { Location loc = op.getLoc(); TypeConverter *typeConverter = getTypeConverter(); - Value dimValue = op.getDim(); - int64_t dim; - if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) - return op.emitError("unimplemented: dim is not constant"); - // Collect all the tensors to be concatenated. auto tensorList = op.getTensors(); SmallVector tensorsTorchType; @@ -1130,6 +1106,14 @@ class ConvertAtenCatOp : public OpConversionPattern { } int rank = newResultType.getRank(); + Value dimValue = op.getDim(); + int64_t dim; + if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) + return op.emitError("unimplemented: dim is not constant"); + dim = toPositiveDim(dim, rank); + if (!isValidDim(dim, rank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + SmallVector offsets, sizes, strides; sizes.reserve(rank); strides.resize(rank, rewriter.create(loc, 1)); @@ -1138,10 +1122,6 @@ class ConvertAtenCatOp : public OpConversionPattern { for (int i = 0; i < rank; ++i) sizes.push_back(rewriter.createOrFold(loc, tensors[0], i)); - dim = toPositiveDim(dim, rank); - if (!isValidDim(dim, rank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - // Calculate the size of the `dim` result dimension by adding the dim size // of each tensor together. Value resultDimSize = sizes[dim]; diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index b40fd22738e2..0aaecb7fbaac 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -79,6 +79,10 @@ class ConvertAtenGatherOp : public OpConversionPattern { int64_t dim; if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); + int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); Value indices = adaptor.getIndex(); Value self = adaptor.getSelf(); @@ -476,6 +480,9 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { int64_t dimInt; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) return op->emitError("unimplemented: dim is not constant"); + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector resultShape = getTensorSizes(rewriter, loc, input); resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0]; diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 92ed647ef8b9..d36b8c309daf 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -512,6 +512,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { op, "only support padding from a list construct"); paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), paddingIntValues); + SmallVector outputPaddingIntValues; + if (!getListConstructElements(op.getOutputPadding(), + outputPaddingIntValues)) + return rewriter.notifyMatchFailure( + op, "only support output_padding from a list construct"); + outputPaddingIntValues = getTypeConvertedValues( + rewriter, loc, getTypeConverter(), outputPaddingIntValues); SmallVector strideInts; if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) return rewriter.notifyMatchFailure(op, @@ -620,6 +627,9 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value outerSize = rewriter.create(loc, offset, c2); outerSize = rewriter.create(loc, outerSize, innerSize); + outerSize = rewriter.create( + loc, outerSize, + castIntToIndex(rewriter, loc, outputPaddingIntValues[i])); outerSizes.push_back(outerSize); offsets.push_back(offset); @@ -643,7 +653,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (size_t i = 0; i < numSpacialDims; i++) outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps( rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], - castIndexToInt(weightDims[i]), strideIntValues[i])); + castIndexToInt(weightDims[i]), strideIntValues[i], + outputPaddingIntValues[i])); // Set stride to 1 strideInts.clear(); diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index d7838d1c75b0..b744ea676f92 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -47,8 +47,22 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter, } kernelSizeIntValues = getTypeConvertedValues( rewriter, op.getLoc(), typeConverter, kernelSizeTorchInt); + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) return rewriter.notifyMatchFailure(op, "only support constant int strides"); + // If `stride` is not specified by the user, it is assigned the value of empty + // list during import. For such a case, the stride value is the kernel size. + // See: + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d + if (strideInts.empty()) { + if (!matchPattern(op.getKernelSize(), + m_TorchListOfConstantInts(strideInts))) { + return rewriter.notifyMatchFailure( + op, "if stride is the empty list, kernel_size must be a list of " + "constant ints"); + } + } + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) return rewriter.notifyMatchFailure(op, "only support constant int paddings"); @@ -80,7 +94,7 @@ static LogicalResult createPoolingOp( highPaddingIncludingNC[2] += strideInts[0]; highPaddingIncludingNC[3] += strideInts[1]; } - Value initValue = rewriter.create(loc, initValueAttr); + Value initValue = rewriter.create(loc, cast(initValueAttr)); paddedInput = torch_to_linalg::getPaddedTensor( op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC, initValue); @@ -154,7 +168,7 @@ class ConvertAtenMaxPool2dOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); Type elementType = self.getType().cast().getElementType(); - auto smallestFPValueAttr = rewriter.getFloatAttr( + TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( elementType, APFloat::getLargest( elementType.cast().getFloatSemantics(), diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index da308ce53ecd..cdc440801c14 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -277,7 +277,7 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value result = payloadArgs[1]; Value self = convertScalarToDtype(b, loc, elem, resultElementType); auto abs = b.create(loc, self); - Attribute twoAttr = b.getFloatAttr(resultElementType, 2.0); + TypedAttr twoAttr = b.getFloatAttr(resultElementType, 2.0); auto ord = b.create(loc, twoAttr); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); @@ -403,7 +403,7 @@ class ConvertReductionOp : public ConversionPattern { return rewriter.notifyMatchFailure(op, "unimplemented: ord = +/- inf"); // Raise each summed value to the inverse of the order of the norm. - Attribute oneAttr = rewriter.getFloatAttr(elemType, 1.0); + TypedAttr oneAttr = rewriter.getFloatAttr(elemType, 1.0); auto oneValue = rewriter.create(loc, oneAttr); auto inverseOrdValue = rewriter.create(loc, oneValue, ordValue); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 49730d5bf39a..500d1ccac1a8 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -229,6 +229,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; if (!clone.getMemoryFormat().getType().isa() && @@ -1119,7 +1123,7 @@ class ConvertElementwiseOp : public ConversionPattern { AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp, - AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op)) + AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -1595,7 +1599,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, - AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, + AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 58f807e74915..27299458de8b 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -134,7 +134,7 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, Value torch_to_linalg::getOutputDimForConvTransposeOps( OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt, - Value kernelSizeInt, Value strideInt) { + Value kernelSizeInt, Value strideInt, Value outputPaddingInt) { Value c1 = b.create(loc, b.getI64IntegerAttr(1)); Value c2 = b.create(loc, b.getI64IntegerAttr(2)); @@ -152,6 +152,7 @@ Value torch_to_linalg::getOutputDimForConvTransposeOps( Value out = b.create(loc, inStrided, doublePadding); out = b.create(loc, out, kernelDilated); + out = b.create(loc, out, outputPaddingInt); out = b.create(loc, out, c1); return castIntToIndex(b, loc, out); diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h index b1f343678abd..5fd5538c264b 100644 --- a/lib/Conversion/TorchToLinalg/Utils.h +++ b/lib/Conversion/TorchToLinalg/Utils.h @@ -49,10 +49,12 @@ Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in, // As above but for transposed convolution ops // Along each dim: // dim_out = -// (dim_in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + 1 +// (dim_in - 1) * stride - 2 * padding + dilation * (kernelSize - 1) + +// output_padding + 1 Value getOutputDimForConvTransposeOps(OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt, - Value kernelSizeInt, Value strideInt); + Value kernelSizeInt, Value strideInt, + Value outputPaddingInt); // Create a reduction of `opInfo.tensorOperand`, reducing along the dimensions // in `opInfo.dimSet`. If `opInfo.keepDim` is true, the output tensor is the diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index d84fbaf9bf04..bc25c7e64ffc 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -11,7 +11,6 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -23,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "utils/hlo_utils.h" #include #include @@ -577,6 +577,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t dimInt; if (matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) { dimInt = toPositiveDim(dimInt, selfType.getRank()); + if (!isValidDim(dimInt, selfType.getRank())) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); dim = rewriter.create(op.getLoc(), dimInt); } else { Value inputRank = rewriter.create( @@ -1189,6 +1191,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "only constant dim param is supported"); } + dim = toPositiveDim(dim, outType.getRank()); + if (!isValidDim(dim, outType.getRank())) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); SmallVector torchTensors; if (!getListConstructElements(op.getTensors(), torchTensors)) { @@ -1203,9 +1208,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( v = hlo::promoteType(rewriter, v, outType); } - size_t posDim = toPositiveDim(dim, outType.getRank()); rewriter.replaceOpWithNewOp( - op, outType, ValueRange(builtinTensors), posDim); + op, outType, ValueRange(builtinTensors), dim); return success(); } @@ -1378,6 +1382,29 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenPowTensorTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value lhs = adaptor.getSelf(); + auto lhsTy = lhs.getType().cast(); + Value rhs = adaptor.getExponent(); + auto rhsTy = rhs.getType().cast(); + + if (!lhsTy || !rhsTy) + return op.emitError("only Tensor types supported"); + + auto outTy = + this->getTypeConverter()->convertType(op.getType()).cast(); + + lhs = hlo::promoteType(rewriter, lhs, outTy); + rhs = hlo::promoteType(rewriter, rhs, outTy); + + rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, + /*broadcast_attr*/ nullptr); + return success(); +} + // RuntimeAssertOp namespace { class ConvertRuntimeAssertOp : public OpConversionPattern { @@ -1521,6 +1548,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenSizeIntOp); INSERT_ATENOP_PATTERN(AtenToDtypeOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); + INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp); #undef INSERT_ATENOP_PATTERN #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt index 237512980562..84a560cd753d 100644 --- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt +++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt @@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo TorchToStablehlo.cpp StablehloLegalizeUtils.cpp Basic.cpp - Gather.cpp + GatherScatter.cpp Linear.cpp ViewLike.cpp Reduction.cpp diff --git a/lib/Conversion/TorchToStablehlo/Gather.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp similarity index 67% rename from lib/Conversion/TorchToStablehlo/Gather.cpp rename to lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 4373327036c7..0118a8a595f2 100644 --- a/lib/Conversion/TorchToStablehlo/Gather.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -11,7 +11,6 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -21,6 +20,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" using namespace mlir; using namespace mlir::torch; @@ -96,6 +96,75 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, sliceSizesTensor, dimsAttr) .getResult(); } + +template +LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + SmallVector &resultShape, + SmallVector &offsets, + SmallVector &strides) { + Location loc = op.getLoc(); + auto input = adaptor.getSelf(); + RankedTensorType inputType = + input.getType().template cast(); + + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return op->emitError("unimplemented: dim is not constant"); + + int64_t inputRank = inputType.getRank(); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + Value dimSize = inputShape[dim]; + + Value torchTypeStart = op.getStart(); + Value torchTypeEnd = op.getEnd(); + Value builtinTypeStart = adaptor.getStart(); + Value builtinTypeEnd = adaptor.getEnd(); + + if (torchTypeStart.getType().isa() || + torchTypeEnd.getType().isa()) + return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); + + int64_t step; + if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { + if (!op.getStep().getType().template isa()) + return op->emitError("unimplemented: step is not constant"); + step = 1; + } + + Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, + builtinTypeStart, zero, dimSize); + Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd, + dimSize, dimSize); + + // end >= start ? end : start + Value endSgeStart = rewriter.create( + loc, arith::CmpIPredicate::sge, end, start); + end = rewriter.create(loc, endSgeStart, end, start); + Value stepIndex = rewriter.create(loc, step); + + // Slice logic: resultSize = floordiv(end - start + step - 1, step) + resultShape = getTensorSizes(rewriter, loc, input); + Value len = rewriter.create(loc, end, start); + Value resultSize = rewriter.create(loc, len, stepIndex); + resultSize = rewriter.create(loc, resultSize, one); + resultSize = rewriter.create(loc, resultSize, stepIndex); + resultShape[dim] = resultSize; + + strides.resize(inputType.getRank(), one); + offsets.resize(inputType.getRank(), zero); + + offsets[dim] = start; + strides[dim] = rewriter.create(loc, strides[dim], stepIndex); + return success(); +} } // namespace // Ref: @@ -159,6 +228,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, "only constant dim is currently supported"); + int64_t inputRank = selfTy.getRank(); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); Value output = gatherTensorAlongSingleAxis( rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits); @@ -258,9 +331,54 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToStablehloOptions &options) { +// AtenSliceScatterOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSliceScatterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op.getLoc(); + TypeConverter *typeConverter = getTypeConverter(); + + auto input = adaptor.getSelf(); + + RankedTensorType resultType = + typeConverter->convertType(op->getResult(0).getType()) + .cast(); + + SmallVector resultShape; + SmallVector offsets; + SmallVector strides; + if (failed(prepareArgumentsForSlicingOp( + op, adaptor, rewriter, resultShape, offsets, strides))) { + return failure(); + } + + Value src = adaptor.getSrc(); + auto srcType = src.getType().cast(); + int64_t srcRank = srcType.getRank(); + SmallVector srcAbstractSizes(srcRank, kUnknownSize); + auto abstractSrcType = RankedTensorType::get( + makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); + Value abstractSrc = + rewriter.create(loc, abstractSrcType, src); + + Value result = rewriter.create( + loc, abstractSrc, input, offsets, resultShape, strides); + + rewriter.replaceOpWithNewOp(op, resultType, result); + + return success(); +} + +void mlir::torch::torch_to_stablehlo:: + populateGatherScatterOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ @@ -269,5 +387,6 @@ void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenEmbeddingOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp); INSERT_ATENOP_PATTERN(AtenGatherOp); + INSERT_ATENOP_PATTERN(AtenSliceScatterOp); #undef INSERT_ATENOP_PATTERN } diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index fbc3d6ee4eb8..0786151cb217 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -11,12 +11,12 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -473,21 +473,21 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { SmallVector weightShapeInt(rank); std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin()); - // 1. [IC, OC, H, W, ...] => [G, IC//G, OC, H, W, ...] + // 1. [H, W, ..., OC, IC] => [H, W, ..., OC, G, IC//G] Value GValue = rewriter.create( op->getLoc(), rewriter.getI64IntegerAttr(groups)); Value ICDivGValue = rewriter.create( - op->getLoc(), weightShapeVec[0], GValue); + op->getLoc(), weightShapeVec[rank - 1], GValue); Value OCMulGValue = rewriter.create( - op->getLoc(), weightShapeVec[1], GValue); - weightShapeVec[0] = ICDivGValue; - weightShapeVec.insert(weightShapeVec.begin(), GValue); + op->getLoc(), weightShapeVec[rank - 2], GValue); + weightShapeVec[rank - 1] = ICDivGValue; + weightShapeVec.insert(weightShapeVec.end() - 1, GValue); - if (weightShapeInt[0] == ShapedType::kDynamic) { - weightShapeInt.insert(weightShapeInt.begin(), groups); + if (weightShapeInt[rank - 1] == ShapedType::kDynamic) { + weightShapeInt.insert(weightShapeInt.end() - 1, groups); } else { - weightShapeInt[0] /= groups; - weightShapeInt.insert(weightShapeInt.begin(), groups); + weightShapeInt[rank - 1] /= groups; + weightShapeInt.insert(weightShapeInt.end() - 1, groups); } Value weightShapeTensor = rewriter.create( op->getLoc(), weightShapeVec); @@ -495,21 +495,21 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), weight, weightShapeTensor); - // 2. [G, IC//G, OC, H, W, ...] => [IC//G, G, OC, H, W, ...] + // 2. [H, W, ..., OC, G, IC//G] => [H, W, ..., G, OC, IC//G] std::vector transposeDims(rank + 1); for (int64_t i = 0; i <= rank; i++) transposeDims[i] = i; - std::swap(transposeDims[1], transposeDims[0]); + std::swap(transposeDims[rank - 1], transposeDims[rank - 2]); weight = rewriter.create( op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims)); - // 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...] - weightShapeInt.erase(weightShapeInt.begin()); - if (weightShapeInt[1] != ShapedType::kDynamic) { - weightShapeInt[1] *= groups; + // 3. [H, W, ..., G, OC, IC//G] => [H, W, ..., G*OC, IC//G] + weightShapeInt.erase(weightShapeInt.end() - 2); + if (weightShapeInt[weightShapeInt.size() - 2] != ShapedType::kDynamic) { + weightShapeInt[weightShapeInt.size() - 2] *= groups; } - weightShapeVec.erase(weightShapeVec.begin()); - weightShapeVec[1] = OCMulGValue; + weightShapeVec.erase(weightShapeVec.end() - 2); + weightShapeVec[weightShapeVec.size() - 2] = OCMulGValue; weightShapeTensor = rewriter.create( op->getLoc(), weightShapeVec); weight = rewriter.create( @@ -524,8 +524,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { Value weight, ArrayRef stride, ArrayRef padding, ArrayRef dilation, - ArrayRef outputPadding, int64_t groups, - bool needHandleOutputPadding) const { + ArrayRef outputPadding, + int64_t groups) const { auto inputTy = input.getType().cast(); auto weightTy = weight.getType().cast(); auto weightShape = weightTy.getShape(); @@ -534,17 +534,24 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { auto nSpatialDims = nDims - 2; auto convOutTy = outType; - if (needHandleOutputPadding) { - SmallVector outShape(nDims); - auto finalOutShape = outType.getShape(); - std::copy(finalOutShape.begin(), finalOutShape.end(), outShape.begin()); - for (int i = 2; i < nDims; ++i) { - if (finalOutShape[i] == ShapedType::kDynamic) - continue; - outShape[i] = finalOutShape[i] - outputPadding[i - 2]; - } - convOutTy = RankedTensorType::get(outShape, outType.getElementType()); + // Transpose weight + SmallVector perm(nDims); + SmallVector transposeShape(nDims); + for (int i = 0; i < nDims; i++) { + if (i < 2) + perm[i] = nDims - 2 + i; + else + perm[i] = nDims - i - 1; + transposeShape[i] = weightShape[perm[i]]; } + auto transposeTy = + RankedTensorType::get(transposeShape, weightTy.getElementType()); + DenseIntElementsAttr permAttr = DenseIntElementsAttr::get( + RankedTensorType::get({nDims}, rewriter.getI64Type()), perm); + auto transposeOp = rewriter.create( + op->getLoc(), transposeTy, weight, permAttr); + auto reverseOp = rewriter.create( + op->getLoc(), transposeOp, rewriter.getI64TensorAttr({0, 1})); // Prepare for transposed convolution SmallVector stablehloStrideVec(nSpatialDims, 1); @@ -554,7 +561,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { for (int i = 0; i < nSpatialDims; ++i) { int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; stablehloPaddingVec[i * 2] = padInt; - stablehloPaddingVec[i * 2 + 1] = padInt; + stablehloPaddingVec[i * 2 + 1] = + padInt + outputPadding[outputPadding.size() - i - 1]; } DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get( RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()), @@ -573,58 +581,35 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { ArrayAttr precisionConfig; SmallVector spatialDims; + SmallVector transposedSpatialDims; for (int i = 0; i < nSpatialDims; ++i) { spatialDims.push_back(i + 2); + transposedSpatialDims.push_back(i); } + stablehlo::ConvDimensionNumbersAttr dimensionNumbers = stablehlo::ConvDimensionNumbersAttr::get( /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*inputFeatureDimension=*/1, /*inputSpatialDimensions=*/spatialDims, - /*kernelInputFeatureDimension=*/0, - /*kernelOutputFeatureDimension=*/1, - /*kernelSpatialDimensions=*/spatialDims, + /*kernelInputFeatureDimension=*/nDims - 1, + /*kernelOutputFeatureDimension=*/nDims - 2, + /*kernelSpatialDimensions=*/transposedSpatialDims, /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, /*outputSpatialDimensions=*/spatialDims); - // Reverse and transpose weight - weight = rewriter.create( - op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims)); + Value weightInput = reverseOp.getResult(); if (groups != 1) { - weight = reshapeConvWeight(rewriter, op, weight, groups); + weightInput = reshapeConvWeight(rewriter, op, reverseOp, groups); } // Create transposed convolution auto transposedConvOp = rewriter.create( - op->getLoc(), convOutTy, input, weight, stablehloStride, + op->getLoc(), convOutTy, input, weightInput, stablehloStride, stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, windowReversal, dimensionNumbers, static_cast(groups), 1, precisionConfig); - - // Handle output padding - if (!needHandleOutputPadding) { - return transposedConvOp.getResult(); - } - SmallVector edgePaddingLowVec(nDims, 0); - SmallVector edgePaddingHighVec(nDims, 0); - SmallVector interiorPaddingVec(nDims, 0); - std::copy(outputPadding.begin(), outputPadding.end(), - edgePaddingHighVec.begin() + 2); - Value paddingValue = - hlo::getConstTensor(rewriter, op, {0.0}, {}).value(); - paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy); - mlir::DenseIntElementsAttr edgePaddingLow = - rewriter.getI64VectorAttr(edgePaddingLowVec); - mlir::DenseIntElementsAttr edgePaddingHigh = - rewriter.getI64VectorAttr(edgePaddingHighVec); - mlir::DenseIntElementsAttr interiorPadding = - rewriter.getI64VectorAttr(interiorPaddingVec); - - auto paddedOutput = rewriter.create( - op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow, - edgePaddingHigh, interiorPadding); - - return paddedOutput.getResult(); + return transposedConvOp.getResult(); } Value convertNormalConv(AtenConvolutionOp op, @@ -763,9 +748,9 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { Value stablehloConvResult; if (transposed) { - stablehloConvResult = convertTransposedConv( - op, rewriter, outTy, input, weight, stride, padding, dilation, - outputPadding, groups, needHandleOutputPadding); + stablehloConvResult = + convertTransposedConv(op, rewriter, outTy, input, weight, stride, + padding, dilation, outputPadding, groups); } else { stablehloConvResult = convertNormalConv(op, rewriter, outTy, input, weight, stride, padding, diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 90044cc8b81e..f40125165a62 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -11,7 +11,6 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -23,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include #include diff --git a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h index b6322efd6897..fc28acfde29f 100644 --- a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h +++ b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h @@ -48,7 +48,7 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, void populateViewLikeOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options); -void populateGatherOpPatternsAndLegality( +void populateGatherScatterOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options); void populateReductionOpPatternsAndLegality( diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index eb4e11116c71..abf98d3ec5ec 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -11,10 +11,10 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" @@ -22,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" using namespace mlir; using namespace mlir::torch; @@ -31,7 +32,8 @@ using namespace mlir::torch::torch_to_stablehlo; static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); - if (isa(op)) { + if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( @@ -597,8 +599,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( // AtenFrobeniusNormDimOp // aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given -// dims) -// + stablehlo.sqrt +// dims) + stablehlo.sqrt namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( @@ -702,6 +703,132 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace +// AtenLinalgVectorNormOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenLinalgVectorNormOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + const TorchToStablehloOptions &options = getOptions(); + + Value input = adaptor.getSelf(); + auto inputType = input.getType().dyn_cast(); + if (!inputType) { + return op.emitError( + "only ranked tensor input supported in AtenLinalgVectorNormOp"); + } + int64_t inputRank = inputType.getRank(); + + auto outType = + getTypeConverter()->convertType(op.getType()).cast(); + auto outElemType = outType.getElementType(); + if (!outElemType.isa()) { + return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp"); + } + + if (inputType.getElementType() != outType.getElementType()) { + input = + rewriter.create(op->getLoc(), input, outElemType); + } + + Value ord = + hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOrd(), outElemType); + + SmallVector dims; + if (failed(checkNotNone(rewriter, op, op.getDim()))) { + dims = llvm::to_vector<4>(llvm::seq(0, inputRank)); + } else { + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); + } + + for (auto &dim : dims) { + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) { + return rewriter.notifyMatchFailure( + op, "invalid dimension detected in `dim`"); + } + } + // Sort the dims in ascending order, making the conversion + // stable with unordered dims. + std::sort(dims.begin(), dims.end()); + } + + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure( + op, "non-const bool `keepdim` is not supported"); + } + + auto initValue = createInitialValueForReduceOp(op, outElemType, rewriter); + if (!initValue) { + return failure(); + } + + Value absValue = rewriter.create(op->getLoc(), input); + Value powValue = rewriter.create(op->getLoc(), absValue, + ord, nullptr); + + auto reduceOp = rewriter.create( + op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims)); + + Region ®ion = reduceOp.getBody(); + Block &block = region.emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, outElemType); + + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + + auto firstArgument = *block.args_begin(); + auto secondArgument = *block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + auto addResult = rewriter.create( + op->getLoc(), firstArgument, secondArgument); + rewriter.create(op->getLoc(), addResult.getResult()); + } + auto constantOne = rewriter.create( + op->getLoc(), blockArgumentTy, + DenseElementsAttr::get( + blockArgumentTy, + APFloat(outElemType.cast().getFloatSemantics(), 1))); + auto reciprocalOrd = rewriter.create( + op->getLoc(), blockArgumentTy, constantOne, ord); + auto output = rewriter.create( + op->getLoc(), reduceOp.getResult(0), reciprocalOrd, nullptr); + + if (keepDim) { + auto outShapeInfo = + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto outShapeVec = *outShapeInfo; + auto one = rewriter.create( + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + for (int64_t i : dims) { + outShapeVec[i] = one; + } + auto outShapeTensor = rewriter.create( + op->getLoc(), outShapeVec); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), output, + outShapeTensor); + return success(); + } + + rewriter.replaceOp(op, output.getResult()); + return success(); +} +} // namespace + void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { @@ -715,5 +842,6 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp); #undef INSERT_ATEN_REDUCTION_OP_PATTERN } diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index dbcfba2ff306..785ae50e6b01 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -6,8 +6,7 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// - -#include "StablehloLegalizeUtils.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index ba08384846cc..434d55c760d3 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -65,7 +65,7 @@ class ConvertTorchToStablehlo typeConverter, patterns, target, options); torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( typeConverter, patterns, target, options); - torch_to_stablehlo::populateGatherOpPatternsAndLegality( + torch_to_stablehlo::populateGatherScatterOpPatternsAndLegality( typeConverter, patterns, target, options); torch_to_stablehlo::populateReductionOpPatternsAndLegality( typeConverter, patterns, target, options); diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index b6511c384068..ea19092e6c8b 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -11,7 +11,6 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" -#include "StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -23,6 +22,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include using namespace mlir; @@ -268,6 +268,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure( op, "only constant dim is currently supported"); + int64_t inputRank = selfTy.getRank(); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto getOptionalVal = [&](Value val) -> std::optional { if (val.getType().isa()) { @@ -343,17 +347,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTy = self.getType().cast(); if (!selfTy) return op.emitError("only ranked tensor types are supported"); - int64_t dim; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) - return rewriter.notifyMatchFailure( - op, "only constant dim is currently supported"); auto rank = selfTy.getRank(); if (rank == 0) return rewriter.notifyMatchFailure( op, "the rank of tensor must be greater than 0"); + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "only constant dim is currently supported"); dim = toPositiveDim(dim, rank); + if (!isValidDim(dim, rank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + if (selfTy.getShape()[dim] != 1) { if (selfTy.getShape()[dim] == ShapedType::kDynamic) return rewriter.notifyMatchFailure( @@ -396,6 +403,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); + int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + dim = toPositiveDim(dim, inputRank + 1); + if (!isValidDim(dim, inputRank + 1)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), {dim}, options.dimSizeIndexBits); diff --git a/lib/Conversion/TorchToTMTensor/CMakeLists.txt b/lib/Conversion/TorchToTMTensor/CMakeLists.txt index eda9ec7ad247..d05d8277c967 100644 --- a/lib/Conversion/TorchToTMTensor/CMakeLists.txt +++ b/lib/Conversion/TorchToTMTensor/CMakeLists.txt @@ -17,6 +17,7 @@ TorchToTMTensor.cpp MLIRMathDialect TorchMLIRTorchDialect TorchMLIRTMTensorDialect + TorchMLIRTorchUtils ) torch_mlir_target_includes(TorchMLIRTorchToTMTensor) diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index c841afcdf947..9f19a7345735 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -56,7 +56,7 @@ using namespace mlir::torch::TMTensor; // that these patterns become mostly mechanical associations of // "aten.foo -> linalg.foo". -static Attribute getNumericLimit(PatternRewriter &rewriter, Type elementType, +static TypedAttr getNumericLimit(PatternRewriter &rewriter, Type elementType, bool getMin = true) { auto bitWidth = elementType.getIntOrFloatBitWidth(); if (llvm::isa(elementType)) { @@ -242,6 +242,60 @@ static Value createTMTensorScanOp( return scanOp->getResult(0); } +// Utility function to create a TMTensor::SortOp. +static FailureOr> +createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, + llvm::ArrayRef operands, + llvm::ArrayRef elementTypes, int64_t dimension, + bool isStable, bool isDescending) { + // Step 1. Create TMTensor::SortOp structure. + SmallVector sortResultTypes; + for (Value val : operands) { + sortResultTypes.push_back(val.getType()); + } + ValueRange inputs; + auto sortOp = rewriter.create( + sortOpLoc, sortResultTypes, inputs, operands, + rewriter.getI64IntegerAttr(dimension)); + + // Step 2. Add two arguments for each element type in the SortOp's block. + Region *body = &sortOp.getRegion(); + Block *block = rewriter.createBlock(body); + Location loc = body->getLoc(); + for (Type elementType : elementTypes) { + block->addArguments({elementType, elementType}, + SmallVector(2, loc)); + } + + // Step 3. Create comparison op which will be used as the sorting predicate. + Value compareOp; + if (auto intType = elementTypes[0].dyn_cast()) { + // Case for using arith::CmpIOp. + arith::CmpIPredicate ge = arith::CmpIPredicate::sge; + arith::CmpIPredicate le = arith::CmpIPredicate::sle; + if (intType.isUnsignedInteger()) { + ge = arith::CmpIPredicate::uge; + le = arith::CmpIPredicate::ule; + } + arith::CmpIPredicate predicate = isDescending ? ge : le; + compareOp = rewriter.create( + loc, predicate, block->getArgument(0), block->getArgument(1)); + } else if (elementTypes[0].isa()) { + // Case for using arith::CmpFOp. + arith::CmpFPredicate predicate = + isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE; + compareOp = rewriter.create( + loc, predicate, block->getArgument(0), block->getArgument(1)); + } else { + return rewriter.notifyMatchFailure( + sortOpLoc, "Only Integer and Floating element type expected."); + } + + // Step 4. Create yield op for yielding the sorting predicate. + rewriter.create(loc, compareOp); + return SmallVector(sortOp.getResults()); +} + namespace { // aten::bincount op counts the frequency of each value in a 1-d input tensor of // non-negative ints. @@ -360,6 +414,234 @@ class ConvertAtenBincountOp : public OpConversionPattern { }; } // namespace +// """Create a map from each dimension of the input tensor to the +// subspace that dimension corresponds to in the result shape one gets +// from indexing the tensor with the optional index tensors. +// +// Note: Index tensors are first broadcasted to a common shape before +// creating the mapping. So the index of every index tensor will map to +// the same dimensions in the result shape. +// +// For example: +// indices = [None, None, torch.randint(4, (6, 1)), torch.randint(5, (7,))] +// indexBroadcastShapeValue = [6, 7] +// map = {0: [0], 1: [1], 2: [2, 3], 3: [2, 3]} +static SmallVector> +getInputShapeToOutputShapeMap(SmallVector optionalIndices, + SmallVector indexBroadcastShapeValue) { + SmallVector indices; + for (Value index : optionalIndices) { + if (!index.getType().isa()) + indices.push_back(index); + } + + unsigned broadcastRank = indexBroadcastShapeValue.size(); + unsigned numIndexTensors = indices.size(); + int64_t indexOfFirstIndexTensor = -1; + SmallVector> result; + + for (unsigned i = 0; i < optionalIndices.size(); i++) { + if (optionalIndices[i].getType().isa()) { + unsigned val = i; + if (indexOfFirstIndexTensor >= 0) + val += broadcastRank - numIndexTensors; + result.push_back({val}); + } else { + if (indexOfFirstIndexTensor < 0) + indexOfFirstIndexTensor = i; + SmallVector outputIndices; + for (unsigned j = indexOfFirstIndexTensor; + j < (indexOfFirstIndexTensor + broadcastRank); j++) + outputIndices.push_back(j); + result.push_back(outputIndices); + } + } + return result; +} + +static std::tuple, SmallVector> +getIndicesFinalShape(ConversionPatternRewriter &rewriter, Location loc, + Value input, SmallVector optionalIndices, + SmallVector inputShapeInt, + SmallVector inputShapeValue, + SmallVector indexBroadcastShapeInt, + SmallVector indexBroadcastShapeValue) { + SmallVector result; + SmallVector resultInt; + bool handledIndexTensorSpace = false; + + for (unsigned i = 0; i < inputShapeValue.size(); i++) { + if (optionalIndices[i].getType().isa()) { + result.push_back(inputShapeValue[i]); + resultInt.push_back(inputShapeInt[i]); + } else { + if (!handledIndexTensorSpace) { + handledIndexTensorSpace = true; + for (unsigned j = 0; j < indexBroadcastShapeValue.size(); j++) { + result.push_back(indexBroadcastShapeValue[j]); + resultInt.push_back(indexBroadcastShapeInt[j]); + } + } + } + } + return std::make_tuple(result, resultInt); +} + +static FailureOr +getScatterIndices(Aten_IndexPutImplOp op, ConversionPatternRewriter &rewriter, + Type indicesDtype, SmallVector optionalIndices, + SmallVector indexBroadcastShapeInt, + SmallVector indexBroadcastShapeValue) { + Location loc = op.getLoc(); + MLIRContext *context = op->getContext(); + Value input = op.getSelf(); + + SmallVector> shapeMap = + getInputShapeToOutputShapeMap(optionalIndices, indexBroadcastShapeValue); + + SmallVector inputShapeInt{ + input.getType().cast().getSizes()}; + int64_t inputRank = inputShapeInt.size(); + SmallVector inputShapeValue; + for (unsigned i = 0; i < inputShapeInt.size(); i++) { + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + inputShapeValue.push_back( + rewriter.createOrFold(loc, input, dim)); + } + + auto finalShapeResult = getIndicesFinalShape( + rewriter, loc, input, optionalIndices, inputShapeInt, inputShapeValue, + indexBroadcastShapeInt, indexBroadcastShapeValue); + SmallVector finalShapeValue = std::get<0>(finalShapeResult); + SmallVector finalShapeInt = std::get<1>(finalShapeResult); + + Value torchCstNone = rewriter.create(loc); + Value torchCstZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value torchCstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + + Value indexBroadcastShapeTorchList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + indexBroadcastShapeValue); + + // Calculating index count. + int64_t indexCount = 1; + if (llvm::all_of(finalShapeInt, + [](int64_t shape) { return shape != kUnknownSize; })) { + for (int64_t i : finalShapeInt) + indexCount *= i; + } else { + indexCount = kUnknownSize; + } + + Value indexCountValue = finalShapeValue[0]; + for (unsigned i = 1; i < finalShapeValue.size(); i++) + indexCountValue = + rewriter.create(loc, indexCountValue, finalShapeValue[i]); + + ValueTensorType flattenIndicesType = + ValueTensorType::get(context, llvm::ArrayRef(indexCount), indicesDtype); + Value flattenEndDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(finalShapeInt.size() - 1)); + + SmallVector broadcastedIndices; + for (unsigned i = 0; i < optionalIndices.size(); i++) { + Value broadcastedIndexTensor; + if (optionalIndices[i].getType().isa()) { + Value torchCstDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + Value inputDim = rewriter.create(loc, input, torchCstDim); + ValueTensorType tensorType = ValueTensorType::get( + context, llvm::ArrayRef(inputShapeInt[i]), indicesDtype); + broadcastedIndexTensor = rewriter.create( + loc, tensorType, /*start=*/torchCstZero, /*end=*/inputDim, + /*step=*/torchCstOne, + /*dtype=*/torchCstNone, + /*layout=*/torchCstNone, + /*device=*/torchCstNone, + /*pin_memory=*/torchCstNone); + } else { + ValueTensorType tensorType = ValueTensorType::get( + context, llvm::ArrayRef(indexBroadcastShapeInt), indicesDtype); + broadcastedIndexTensor = rewriter.create( + loc, tensorType, optionalIndices[i], indexBroadcastShapeTorchList); + } + + // spotlight_indices(final_shape, shape_map[i]): + // Turn all values in `final_shape` to `1` except for those with index in + // `indices`. + // for j in range(len(final_shape)): + // if j not in indices: + // final_shape[j] = 1 + // This is equivalent to unsqueezing the index tensor at the dimension `j` + // not in indices. + for (unsigned j = 0; j < finalShapeInt.size(); j++) { + if (llvm::find(shapeMap[i], j) == shapeMap[i].end()) { + Value unsqueezeDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(j)); + auto unsqueezedInfo = + unsqueezeTensor(rewriter, op, broadcastedIndexTensor, + /*dim=*/unsqueezeDim); + if (failed(unsqueezedInfo)) { + return rewriter.notifyMatchFailure( + op, "cannot generate unsqueeze tensor op"); + } + broadcastedIndexTensor = *unsqueezedInfo; + } + } + + // Performing broadcast to final shape. + Value broadcastShapeTorchList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + finalShapeValue); + ValueTensorType broadcastTensorType = ValueTensorType::get( + context, llvm::ArrayRef(finalShapeInt), indicesDtype); + broadcastedIndexTensor = rewriter.create( + loc, broadcastTensorType, broadcastedIndexTensor, + broadcastShapeTorchList); + + // Flattening the tensor. + broadcastedIndexTensor = rewriter.create( + loc, flattenIndicesType, broadcastedIndexTensor, torchCstZero, + flattenEndDim); + + broadcastedIndices.push_back(broadcastedIndexTensor); + } + + // Stacking broadcasted indices. + Value scatterIndices; + // The operation torch.stack([a, b], dim=0) is decomposed into: + // torch.cat([a.unsqueeze(dim=0), b.unsqueeze(dim=0)], dim=0) + // Unsqueeze all tensors before concatenating. + SmallVector unsqueezedIndexTensors; + for (Value tensor : broadcastedIndices) { + auto unsqueezedInfo = + unsqueezeTensor(rewriter, op, tensor, /*dim=*/torchCstZero); + if (failed(unsqueezedInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor op"); + } + unsqueezedIndexTensors.push_back(*unsqueezedInfo); + } + + BaseTensorType unsqueezedTensorType = + unsqueezedIndexTensors[0].getType().cast(); + Value concatIndicesTorchList = rewriter.create( + loc, Torch::ListType::get(unsqueezedTensorType), unsqueezedIndexTensors); + ValueTensorType concatIndicesType = ValueTensorType::get( + context, llvm::ArrayRef({inputRank, indexCount}), indicesDtype); + scatterIndices = rewriter.create( + loc, concatIndicesType, concatIndicesTorchList, torchCstZero); + + ValueTensorType transposedIndicesType = ValueTensorType::get( + context, llvm::ArrayRef({indexCount, inputRank}), indicesDtype); + scatterIndices = rewriter.create( + loc, transposedIndicesType, scatterIndices, torchCstZero, torchCstOne); + return scatterIndices; +} + namespace { class ConvertAten_IndexPutImplOp : public OpConversionPattern { @@ -376,9 +658,15 @@ class ConvertAten_IndexPutImplOp Value values = adaptor.getValues(); RankedTensorType inputType = input.getType().cast(); RankedTensorType valuesType = values.getType().cast(); + int64_t inputRank = inputType.getRank(); + auto valuesTensorType = op.getValues().getType().cast(); auto resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); + if (!valuesTensorType.hasSizes()) + return rewriter.notifyMatchFailure( + op, "unimplemented: the values tensor type must have sizes."); + // The unsafe should be either `False` or `none`. if (!op.getUnsafe().getType().isa()) { bool unsafe; @@ -401,51 +689,194 @@ class ConvertAten_IndexPutImplOp return rewriter.notifyMatchFailure( op, "Input element type should be same as the values element type."); - SmallVector indicesList; - getListConstructElements(adaptor.getIndices(), indicesList); + SmallVector optionalIndicesList; + getListConstructElements(op.getIndices(), optionalIndicesList); // The size of the list of the index tensors should not be greater than the // input rank. - if ((int64_t)indicesList.size() > inputType.getRank()) + if ((int64_t)optionalIndicesList.size() > inputRank) return rewriter.notifyMatchFailure( op, "Indices list size should not be greater than the input rank."); - // TODO: Add support for cases with indices list size not equal to 1. - if (indicesList.size() != 1) + Value torchCstNone = rewriter.create(loc); + unsigned sizeOptionalIndicesList = optionalIndicesList.size(); + SmallVector nonNoneIndexTensorDim; + unsigned numNonNoneIndices; + + if (sizeOptionalIndicesList == 0) + return rewriter.notifyMatchFailure(op, "Indices list must not be empty."); + + for (unsigned i = 0; i < optionalIndicesList.size(); i++) { + if (!optionalIndicesList[i].getType().isa()) { + nonNoneIndexTensorDim.push_back(i); + } + } + + numNonNoneIndices = nonNoneIndexTensorDim.size(); + if (numNonNoneIndices > 2) { + return rewriter.notifyMatchFailure( + op, "unimplemented: non none index tensors less than or equal to 2 " + "supported only"); + } else if (numNonNoneIndices == 2 && + nonNoneIndexTensorDim[0] != nonNoneIndexTensorDim[1] - 1) { return rewriter.notifyMatchFailure( - op, "Unimplemented: Indices list size != 1"); - Value indexTensor = indicesList[0]; + op, "unimplemented: case of 2 non none index tensors is supported " + "only when both the tensors are along consecutive dimensions"); + } - if (indexTensor.getType().isa()) - return rewriter.notifyMatchFailure(op, "Index tensor must not be None."); + // Padding the indices list with none values. + if (sizeOptionalIndicesList < inputRank) { + for (unsigned i = 0; i < (inputRank - sizeOptionalIndicesList); i++) + optionalIndicesList.push_back(torchCstNone); + } - // Creating a tm_tensor.scatter op with the following mapping: - // 1.) Index tensor from the `indicesList` maps to the indices in scatter - // op. Index tensor is expanded from 1-d to 2-d, and its element type is set - // to i32 as required for the scatter op. - // 2.) `values` is mapped to `updates` in scatter op. - // 3.) `input` is mapped to `original` in scatter op. - std::optional indexTensorRank = getTensorRank(indexTensor); - if (!indexTensorRank || *indexTensorRank != 1) + SmallVector indexBroadcastShapeInt{ + optionalIndicesList[nonNoneIndexTensorDim[0]] + .getType() + .cast() + .getSizes()}; + SmallVector indexBroadcastShapeValue; + if (numNonNoneIndices == 2) { + computeBroadcastShape(rewriter, loc, + optionalIndicesList[nonNoneIndexTensorDim[0]], + optionalIndicesList[nonNoneIndexTensorDim[1]], + indexBroadcastShapeInt, indexBroadcastShapeValue); + } else { + // It means there's only one index tensor and broadcast shape is same as + // that index tensor' shape. + for (unsigned i = 0; i < indexBroadcastShapeInt.size(); i++) { + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + indexBroadcastShapeValue.push_back(rewriter.createOrFold( + loc, optionalIndicesList[nonNoneIndexTensorDim[0]], dim)); + } + } + + Type indicesDtype = optionalIndicesList[nonNoneIndexTensorDim[0]] + .getType() + .cast() + .getDtype(); + + // This implementation is done to get the scatter indices: + + // def get_broadcast_shape(tensors): + // return list(torch.broadcast_tensors(*tensors)[0].shape) + + // def get_input_shape_to_output_shape_map(optional_index_tensors: + // list[Optional[torch.Tensor]]): + // index_tensors = list(filter(lambda x: x is not None, + // optional_index_tensors)) broadcast_rank = + // len(get_broadcast_shape(index_tensors)) num_of_index_tensors = + // len(index_tensors) index_of_first_index_tensor: Optional[int] = None + // result = {} + // for i, index in enumerate(optional_index_tensors): + // if index is None: + // val = i + // if index_of_first_index_tensor is not None: + // val += broadcast_rank - num_of_index_tensors + // result[i] = [val] + // else: + // if index_of_first_index_tensor is None: + // index_of_first_index_tensor = i + // output_indices = list(range(index_of_first_index_tensor, + // index_of_first_index_tensor + + // broadcast_rank)) + // result[i] = output_indices + // return result + + // def spotlight_indices(shape, indices: list[int]): + // """Turn all values in `shape` to `1` except for those with index in + // `indices`.""" shape = shape.copy() for i in range(len(shape)): + // if i not in indices: + // shape[i] = 1 + // return shape + + // def get_final_shape(input, optional_index_tensors: + // list[Optional[torch.Tensor]]): + // index_tensors = list(filter(lambda x: x is not None, + // optional_index_tensors)) index_tensors_broadcast_shape = + // get_broadcast_shape(index_tensors) result = [] + // handled_index_tensor_space = False + // for e, i in enumerate(input.shape): + // if optional_index_tensors[e] is None: + // result.append(i) + // else: + // if not handled_index_tensor_space: + // handled_index_tensor_space = True + // result += index_tensors_broadcast_shape + // return result + + // def get_scatter_indices(input, optional_index_tensors: + // list[Optional[torch.Tensor]]): + // assert len(input.size()) == len(optional_index_tensors), "Pad indices + // with None" shape_map = + // get_input_shape_to_output_shape_map(optional_index_tensors) + // index_tensors = list(filter(lambda x: x is not None, + // optional_index_tensors)) index_tensors_broadcast_shape = + // get_broadcast_shape(index_tensors) final_shape = + // get_final_shape(input, optional_index_tensors) + + // broadcasted_index_tensors = [] + // for e, optional_index_tensor in enumerate(optional_index_tensors): + // if optional_index_tensor is None: + // tensor_to_broadcast = torch.arange(0, input.size(e)) + // else: + // tensor_to_broadcast = + // optional_index_tensor.broadcast_to(index_tensors_broadcast_shape) + + // broadcasted_index_tensor = \ + // tensor_to_broadcast.reshape(spotlight_indices(final_shape, shape_map[e]))\ + // .broadcast_to(final_shape)\ + // .flatten() + // broadcasted_index_tensors.append(broadcasted_index_tensor) + + // return torch.stack(broadcasted_index_tensors, dim=0).t() + + auto scatterIndicesInfo = + getScatterIndices(op, rewriter, indicesDtype, optionalIndicesList, + indexBroadcastShapeInt, indexBroadcastShapeValue); + if (failed(scatterIndicesInfo)) { return rewriter.notifyMatchFailure( - op, "unimplemented: index tensor with rank != 1 is not supported"); - auto indexTensorType = indexTensor.getType().cast(); - int64_t indexTensorSize = indexTensorType.getSizes()[0]; - SmallVector expandedIndexTensorSizes{indexTensorSize, 1}; - ValueTensorType expandedIndexTensorType = - ValueTensorType::get(context, llvm::ArrayRef(expandedIndexTensorSizes), - indexTensorType.getDtype()); - Value torchCstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value expandedIndexTensor = rewriter.create( - loc, expandedIndexTensorType, indexTensor, torchCstOne); + op, "cannot generate scatter indices for index put op"); + } + Value indexTensor = *scatterIndicesInfo; + + // Flattening the values tensor. + Value torchCstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value flattenedValuesTensorLastDim = rewriter.create( + loc, + rewriter.getI64IntegerAttr(valuesTensorType.getSizes().size() - 1)); + SmallVector valuesShapeInt{valuesTensorType.getSizes()}; + int64_t valuesCount = 1; + if (llvm::all_of(valuesShapeInt, + [](int64_t shape) { return shape != kUnknownSize; })) { + for (int64_t i : valuesShapeInt) + valuesCount *= i; + } else { + valuesCount = kUnknownSize; + } + auto flattenedValuesTensorType = ValueTensorType::get( + context, llvm::ArrayRef(valuesCount), valuesTensorType.getDtype()); + Value flattenedValuesTensor = rewriter.create( + loc, flattenedValuesTensorType, op.getValues(), torchCstZero, + flattenedValuesTensorLastDim); + values = typeConverter->materializeTargetConversion( + rewriter, loc, + typeConverter->convertType(flattenedValuesTensor.getType()), + flattenedValuesTensor); // `TMTensor::ScatterOp` expects indices of element type i32. Value indices = convertTensorToDtype( - rewriter, loc, expandedIndexTensor, + rewriter, loc, indexTensor, mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); + // Creating a tm_tensor.scatter op with the following mapping: + // 1.) Index tensor from the `indicesList` maps to the indices in scatter + // op. + // 2.) `values` is mapped to `updates` in scatter op. + // 3.) `input` is mapped to `original` in scatter op. bool invalidInputTypeFound = false; Value scatterOp = createTMTensorScatterOp( rewriter, loc, values, indices, input, /*uniqueIndices=*/false, @@ -753,7 +1184,7 @@ class ConvertAtenScatterReduceTwoOp if (reduceEnum == torch_upstream::ReductionType::MEAN) { SmallVector selfShape = getTensorSizes(rewriter, loc, adaptor.getSelf()); - Attribute initAttr; + TypedAttr initAttr; if (llvm::isa(srcType.getElementType())) { initAttr = rewriter.getFloatAttr(srcType.getElementType(), 1); } else if (llvm::isa(srcType.getElementType())) { @@ -789,13 +1220,13 @@ class ConvertAtenScatterReduceTwoOp } else if (reduceEnum == torch_upstream::ReductionType::MAX) { // Set the values in the input tensor to the smallest element of that // type - auto minAttr = getNumericLimit(rewriter, srcType.getElementType(), + TypedAttr minAttr = getNumericLimit(rewriter, srcType.getElementType(), /*getMin=*/true); normalizationValue = rewriter.create(loc, minAttr); } else if (reduceEnum == torch_upstream::ReductionType::MIN) { // Set the values in the input tensor to the largest element of that // type - auto maxAttr = getNumericLimit(rewriter, srcType.getElementType(), + TypedAttr maxAttr = getNumericLimit(rewriter, srcType.getElementType(), /*getMin=*/false); normalizationValue = rewriter.create(loc, maxAttr); } @@ -920,6 +1351,93 @@ class ConvertAtenScatterReduceTwoOp }; } // namespace +namespace { +class ConvertAtenSortOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenSortOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + + // Step 1. Fetch Input to sort. + Value inputTensor = adaptor.getSelf(); + auto inputType = inputTensor.getType().cast(); + unsigned inputRank = inputType.getRank(); + + // Step 2. Fetch dimension to perform sort in. + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant dim value is supported"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) { + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + } + + // Step 3. Fetch the order of sorting. + bool descending; + if (!matchPattern(op.getDescending(), m_TorchConstantBool(&descending))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant descending value is supported"); + + // Step 4. Form a RankedTensorType with same shape as that of the input's + // but with elemental type i64. + RankedTensorType indicesType = + RankedTensorType::get(inputType.getShape(), rewriter.getI64Type()); + + // Step 5. Generate indices tensor. + SmallVector dynDims; + for (unsigned i = 0; i < inputType.getRank(); i++) { + if (inputType.isDynamicDim(i)) { + dynDims.push_back(rewriter.create(loc, inputTensor, i)); + } + } + Value initEmptyTensor = rewriter.create( + loc, inputType.getShape(), rewriter.getI64Type(), dynDims); + + SmallVector indexingMaps = { + AffineMap::getMultiDimIdentityMap(inputRank, op.getContext())}; + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + Value indicesTensor = + rewriter + .create( + loc, initEmptyTensor.getType(), ValueRange{}, initEmptyTensor, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value index = b.create(loc, dim); + index = castIndexToInt64(b, loc, index); + b.create(loc, index); + }) + .getResult(0); + + // Step 6. Create TMTensor::SortOp. + SmallVector operands; + operands.push_back(inputTensor); + operands.push_back(indicesTensor); + SmallVector elementTypes; + elementTypes.push_back(inputType.getElementType()); + elementTypes.push_back(indicesType.getElementType()); + + // The default value for aten.sort op's `stable` parameter is `false`. + // Refer: https://pytorch.org/docs/stable/generated/torch.sort.html + FailureOr> sortOpValues = + createTMTensorSortOp(rewriter, loc, operands, elementTypes, + /*dimension=*/dim, /*isStable=*/false, + /*isDescending=*/descending); + if (failed(sortOpValues)) + return rewriter.notifyMatchFailure( + loc, "Only Integer and Floating element type expected."); + + auto sortOpVal = *sortOpValues; + rewriter.replaceOp(op, sortOpVal); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenCumsumOp : public OpConversionPattern { public: @@ -1014,6 +1532,8 @@ class ConvertTorchToTMTensor context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToTosa/CMakeLists.txt b/lib/Conversion/TorchToTosa/CMakeLists.txt index e1f5142bd924..909ee3bcba26 100644 --- a/lib/Conversion/TorchToTosa/CMakeLists.txt +++ b/lib/Conversion/TorchToTosa/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa MLIRIR MLIRPass MLIRTosaDialect + TorchMLIRConversionUtils TorchMLIRTorchDialect ) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index dfe655fc915e..71219908e720 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -717,6 +717,12 @@ class ConvertAtenMultipleDimsReductionOp return rewriter.notifyMatchFailure(op, "non-const dim parameter unsupported"); int64_t N = reduceDims.size(); + int64_t inputRank = adaptor.getSelf().getType().template cast().getRank(); + for (unsigned i=0; i().getRank(); + reduceDim = toPositiveDim(reduceDim, inputRank); + if (!isValidDim(reduceDim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto reduceDimsType = RankedTensorType::get({1}, rewriter.getI64Type()); reduceDimsAttr = DenseIntElementsAttr::get(reduceDimsType, llvm::ArrayRef({reduceDim})); @@ -806,6 +816,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDim(), m_TorchConstantInt(&reduceDim))) { // NoneType indicates reduce on all dims reduceDim = -1; + } else { + int64_t inputRank = selfTy.getRank(); + reduceDim = toPositiveDim(reduceDim, inputRank); + if (!isValidDim(reduceDim, inputRank)) + return rewriter.notifyMatchFailure(op, "reduce dim is statically invalid"); } bool keepDim = false; @@ -3171,8 +3186,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t end; if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) return rewriter.notifyMatchFailure(op, "end must be a Scalar constant"); + // support for end < 0 + end = toPositiveDim(end, selfType.getShape()[dim]); - // FIXME: add support for start/end < 0 and end < start + // FIXME: add support for start < 0 and end < start if (end < start) return rewriter.notifyMatchFailure(op, "Currently unsupported: end < start"); @@ -3499,7 +3516,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Support for multiple index auto index = indexTensors[0]; auto indexTorch = tensorsTorchType[0]; - // TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None)) + // TODO add support for none index input like torch.ops.aten.index(x, (None, + // index1, index2, None)) if (indexTorch.getType().isa()) return rewriter.notifyMatchFailure( op, "Only list ranked tensor types index are supported"); @@ -3540,6 +3558,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenAbsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Not a tensor type. + auto selfType = adaptor.getSelf().getType().dyn_cast(); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + + auto outType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf()); + + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenWhereSelfOp op, OpAdaptor adaptor, @@ -3563,6 +3597,32 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLeTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Not a tensor type. + auto selfType = adaptor.getSelf().getType().dyn_cast(); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + auto otherType = adaptor.getOther().getType().dyn_cast(); + if (!otherType) + return rewriter.notifyMatchFailure( + op, "Only tensor types condition are currently supported"); + + auto outType = getTypeConverter()->convertType(op.getType()); + + auto greaterOp = rewriter.create( + op.getLoc(), outType, adaptor.getSelf(), adaptor.getOther()); + + rewriter.replaceOpWithNewOp(op, outType, + greaterOp.getOutput()); + + return success(); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenClampOp op, OpAdaptor adaptor, @@ -3658,13 +3718,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Only supports integer operand type, because for the floating point operand // type result tensor has to be of type `f64` which is not supported in the // tosa. - int64_t initValue; - if (!matchPattern(op.getA(), m_TorchConstantInt(&initValue))) - return rewriter.notifyMatchFailure( - op, "unimplemented: input should be a torch constant int"); + double doubleValue; + auto isDouble = matchPattern(op.getA(), m_TorchConstantFloat(&doubleValue)); + int64_t intValue; + auto isInt = matchPattern(op.getA(), m_TorchConstantInt(&intValue)); + if (!isDouble && !isInt) + return rewriter.notifyMatchFailure(op, + "Unable to extract the scalar constant"); + + auto outElemTy = resultType.getElementType(); + if (outElemTy.isa()) { + rewriter.replaceOpWithNewOp(op, resultType, DenseElementsAttr::get(resultType, {intValue})); + } else if (outElemTy.isF64()) { + rewriter.replaceOpWithNewOp(op, resultType, DenseElementsAttr::get(resultType, {doubleValue})); + } - DenseElementsAttr constAttr = DenseElementsAttr::get(resultType, {initValue}); - rewriter.replaceOpWithNewOp(op, resultType, constAttr); return success(); } @@ -3772,6 +3840,58 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenRemainderScalarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + Value self = adaptor.getSelf(); + auto selfTy = self.getType().template cast(); + + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Remainder"); + + auto outType = + getTypeConverter()->convertType(op.getType()).template cast(); + + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); + + Value otherTensor; + Value other = op.getOther(); + if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor, + outElemTy, {}))) + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA Remainder operation"); + + if (selfTy.getElementType() != outElemTy) + self = rewriter.create(op.getLoc(), outType, self); + + auto divTensor = self; + // tosa::DivOp only supports int + if (outElemTy.isa()) { + auto otherTensorReciprocal = rewriter.create( + op.getLoc(), otherTensor.getType(), otherTensor); + divTensor = rewriter.create( + op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0); + divTensor = rewriter.create(op.getLoc(), outType, divTensor); + } else { + divTensor = + rewriter.create(op.getLoc(), outType, self, otherTensor); + } + + auto mulTensor = + rewriter.create(op.getLoc(), outType, otherTensor, divTensor, + /*shift=*/0); + rewriter.replaceOpWithNewOp(op, outType, self, mulTensor); + + return success(); +} + template class ConvertAtenPoolingBaseOp : public OpConversionPattern { public: @@ -3793,14 +3913,16 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { static int64_t getOutputDim(int64_t inputDim, int64_t kernelDim, int64_t stride, int64_t padBefore, - int64_t padAfter, int64_t dilation) { + int64_t padAfter, int64_t dilation, + bool ceilMode = false) { if (inputDim == kUnknownSize) { return kUnknownSize; } else { - return ( - (inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1) / - stride + - 1); + int64_t dimSize = + inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; + if (ceilMode && (dimSize % stride != 0)) + return dimSize / stride + 2; + return dimSize / stride + 1; } } @@ -3977,17 +4099,23 @@ template static Type getOutputTypeForNonAdaptivePoolingOp( RankedTensorType inputTy, SmallVectorImpl &kernelSize, SmallVectorImpl &strideArray, SmallVectorImpl &padArray, - SmallVectorImpl &dilationArray) { + SmallVectorImpl &dilationArray, bool ceilMode = false) { auto inputShape = makeShapeTorchCompatible(inputTy.getShape()); auto inputRank = inputTy.getRank(); auto inputElemTy = inputTy.getElementType(); int64_t outputHDim = ConvertAtenPoolingBaseOp::getOutputDim( inputShape[inputRank - 2], kernelSize[0], strideArray[0], padArray[0], - padArray[0], dilationArray[0]); + padArray[0], dilationArray[0], ceilMode); int64_t outputWDim = ConvertAtenPoolingBaseOp::getOutputDim( inputShape[inputRank - 1], kernelSize[1], strideArray[1], padArray[1], - padArray[1], dilationArray[1]); + padArray[1], dilationArray[1], ceilMode); + padArray[0] = (outputHDim - 1) * strideArray[0] + + dilationArray[0] * kernelSize[0] - dilationArray[0] + 1 - + padArray[0] * 2 - inputShape[inputRank - 2]; + padArray[1] = (outputWDim - 1) * strideArray[1] + + dilationArray[0] * kernelSize[1] - dilationArray[0] + 1 - + padArray[1] * 2 - inputShape[inputRank - 1]; SmallVector outputShape; if (inputRank > 3) outputShape.push_back(inputShape[0]); @@ -4023,30 +4151,38 @@ static LogicalResult getOutputTypeAndPoolingParameters( m_TorchListOfConstantInts(kernelSizeInts))) return rewriter.notifyMatchFailure( op, "Non-const kernel_size for pooling op unsupported"); + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) return rewriter.notifyMatchFailure( op, "Non-const stride for pooling op unsupported"); + // If `stride` is not specified by the user, it is assigned the value of empty + // list during import. For such a case, the stride value is the kernel size. + // See: + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d + if (strideInts.empty()) + strideInts.assign(kernelSizeInts); + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) return rewriter.notifyMatchFailure( op, "Non-const padding factor for pooling op unsupported"); + SmallVector padArr = {paddingInts[0], paddingInts[0], + paddingInts[1], paddingInts[1]}; kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts); stride = rewriter.getDenseI64ArrayAttr(strideInts); - pad = rewriter.getDenseI64ArrayAttr( - {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}); - // FIXME: add ceil_mode support. bool ceilMode; if (!matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode))) return rewriter.notifyMatchFailure( op, "only support constant bool ceil_mode for pooling op"); - if (ceilMode) - return rewriter.notifyMatchFailure( - op, "only support ceil_mode equals to False for pooling op"); outputTy = getOutputTypeForNonAdaptivePoolingOp( - inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray); - + inputTy, kernelSizeInts, strideInts, paddingInts, dilationArray, + ceilMode); + padArr[1] = padArr[1] + paddingInts[0]; + padArr[3] = padArr[3] + paddingInts[1]; + pad = rewriter.getDenseI64ArrayAttr( + {padArr[0], padArr[1], padArr[2], padArr[3]}); return success(); } @@ -4136,9 +4272,16 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { // FIXME: Handle layout, device and pin_memory. Assume dtype has been // processed to set output type correctly? - if (!op.getLayout().getType().template isa()) - return rewriter.notifyMatchFailure(op, - "Only default layout is supported"); + // The layout arg should be either `none` or `0` i.e. strided. + if (!op.getLayout().getType().template isa()) { + int64_t tensorLayout; + if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) + return rewriter.notifyMatchFailure( + op, "The layout arg should be either `none` or `0` i.e. strided."); + else if (tensorLayout != torch_upstream::Layout::Strided) + return rewriter.notifyMatchFailure( + op, "The layout arg should be either `none` or `0` i.e. strided."); + } bool pinMemory; if (!op.getPinMemory().getType().template isa() && @@ -4291,6 +4434,75 @@ class ConvertAtenCloneOp : public OpConversionPattern { } }; +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenConstantPadNdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Value self = adaptor.getSelf(); + auto selfTy = self.getType().cast(); + auto selfElemTy = selfTy.getElementType(); + int64_t rank = selfTy.getRank(); + + // START the code snippet from + // lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: + // ConvertAtenConstantPadNdOp) Pattern match against the op's original + // operands, because otherwise we will get the lowered version of the operands + // which is harder to pattern match. + SmallVector padInts; + if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int pad ranges"); + uint64_t padRank = padInts.size() / 2; + if (padRank * 2 != padInts.size()) + return rewriter.notifyMatchFailure(op, "pad range size is not even"); + if (rank < 0 || padRank > (uint64_t)rank) + return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); + + // Initialize low/high paddings with 0 for all the dims. + SmallVector lowPadding(/*Size=*/rank, /*Value=*/0); + SmallVector highPadding(/*Size=*/rank, /*Value=*/0); + // Add the requested padding - note op.pad() is highest dim first ordered + // pairs of low,high. + for (uint64_t i = 0; i < padRank; ++i) { + lowPadding[rank - i - 1] = padInts[i * 2]; + highPadding[rank - i - 1] = padInts[i * 2 + 1]; + } + // END the code snippet from + // lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: + // ConvertAtenConstantPadNdOp) + + llvm::SmallVector translatePadsList; + + for (unsigned int i = 0; i < rank; i++) { + translatePadsList.push_back(lowPadding[i]); + translatePadsList.push_back(highPadding[i]); + } + + DenseElementsAttr paddingAttr = DenseIntElementsAttr::get( + RankedTensorType::get({rank, 2}, rewriter.getI64Type()), + translatePadsList); + + Value padsList1 = rewriter.create( + loc, paddingAttr.getType(), paddingAttr); + + Value padValue = adaptor.getValue(); + Operation *padOp = padValue.getDefiningOp(); + padValue = padOp->getOperand(0); + + Value padTensor; + if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), padValue, + padTensor, selfElemTy, {}))) + return rewriter.notifyMatchFailure( + op, "Pad value needs to be a scalar constant for conversion to " + "TOSA pad operation"); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self, padsList1, + padTensor); + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -4508,12 +4720,16 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenGatherOp); INSERT_ATENOP_PATTERN(AtenIndexTensorOp); + INSERT_ATENOP_PATTERN(AtenAbsOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); + INSERT_ATENOP_PATTERN(AtenLeTensorOp); INSERT_ATENOP_PATTERN(AtenClampOp); INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenCopyOp); INSERT_ATENOP_PATTERN(AtenToDtypeOp); + INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); + INSERT_ATENOP_PATTERN(AtenRemainderScalarOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 1583506305f3..09be73436eb6 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -235,6 +235,7 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) { (src.isF32() && dest.isF64()) || (src.isF64() && dest.isF32()) || (src.isF32() && dest.isInteger(8)) || + (src.isF32() && dest.isInteger(64)) || (src.isF32() && dest.isInteger(1))) { return success(); } diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 906cc3c4458d..1f6a889b5567 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -197,7 +197,7 @@ Value getTensorSize(OpBuilder &b, Location loc, Value tensor) { // Creates a constant of type `elemType` with value `val`. Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) { - Attribute attr = {}; + TypedAttr attr = {}; if (elemType.isa()) attr = b.getFloatAttr(elemType, val); if (elemType.isa()) @@ -324,6 +324,106 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, llvm_unreachable("convertScalarToDtype should handle all the types"); } +Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, + Value torchOptionalInt, Value builtinInt, + Value defaultValue, Value dimSize) { + if (torchOptionalInt.getType().isa()) + return defaultValue; + auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); + Value positiveDim = + toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt); + // positiveDim < 0 ? 0 : positiveDim + Value cst0 = rewriter.create( + loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); + Value predDimSltZero = rewriter.create( + loc, arith::CmpIPredicate::slt, positiveDim, cst0); + Value atLeastZero = + rewriter.create(loc, predDimSltZero, cst0, positiveDim); + // atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero + Value sgtDimSize = rewriter.create( + loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt); + Value boundedByDimSize = rewriter.create( + loc, sgtDimSize, dimSizeAsInt, atLeastZero); + + return castIntToIndex(rewriter, loc, boundedByDimSize); +} + +// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If +// yes, then computes the final broadcast shape. +void computeBroadcastShape(ConversionPatternRewriter &rewriter, Location loc, + Value inputA, Value inputB, + SmallVector &resultShape, + SmallVector &resultShapeValue) { + SmallVector shapeA{ + inputA.getType().cast().getSizes()}; + SmallVector shapeB{ + inputB.getType().cast().getSizes()}; + unsigned rankA = shapeA.size(); + unsigned rankB = shapeB.size(); + unsigned minRank = rankA > rankB ? rankB : rankA; + // Check whether the shapes of the tensors are broadcastable or not. + // Two tensors are “broadcastable” if the following rules hold: + // 1.) Each tensor has at least one dimension. + // 2.) When iterating over the dimension sizes, starting at the trailing + // dimension, the dimension sizes must either be equal, one of them is 1, or + // one of them does not exist. + for (unsigned i = 0; i < minRank; i++) { + Value sizeDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(rankA - i - 1)); + Value sizeDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(rankB - i - 1)); + Value sizeInputA = + rewriter.createOrFold(loc, inputA, sizeDimA); + Value sizeInputB = + rewriter.createOrFold(loc, inputB, sizeDimB); + Value torchCstOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value cmpSizeAEqualsSizeB = + rewriter.create(loc, sizeInputA, sizeInputB); + Value cmpSizeAEqualsOne = + rewriter.create(loc, sizeInputA, torchCstOne); + Value cmpSizeBEqualsOne = + rewriter.create(loc, sizeInputB, torchCstOne); + Value anyBoolOpList = rewriter.create( + loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()), + SmallVector{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne, + cmpSizeBEqualsOne}); + Value cmp = rewriter.create(loc, anyBoolOpList); + rewriter.create( + loc, cmp, "tensors are not broadcast compatible"); + } + // If we reach here then it means both the shapes are broadcast compatible. + resultShape = rankA >= rankB ? shapeA : shapeB; + Value shapeTensor = rankA >= rankB ? inputA : inputB; + for (unsigned i = 0; i < resultShape.size(); i++) { + Value sizeDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + resultShapeValue.push_back( + rewriter.createOrFold(loc, shapeTensor, sizeDim)); + } + + unsigned resultRank = resultShape.size(); + for (unsigned i = 0; i < minRank; i++) { + Value sizeDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(rankA - i - 1)); + Value sizeDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(rankB - i - 1)); + Value sizeInputA = + rewriter.createOrFold(loc, inputA, sizeDimA); + Value sizeInputB = + rewriter.createOrFold(loc, inputB, sizeDimB); + resultShapeValue[resultRank - i - 1] = + rewriter.create(loc, sizeInputA, sizeInputB); + if (shapeA[rankA - i - 1] == kUnknownSize || + shapeB[rankB - i - 1] == kUnknownSize) { + resultShape[resultRank - i - 1] = kUnknownSize; + } else { + resultShape[resultRank - i - 1] = + std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]); + } + } +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index df75d1b64387..28506d6eab9b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -409,7 +409,7 @@ static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Block *block = ®ion.front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); - rewriter.mergeBlockBefore(block, op, blockArgs); + rewriter.inlineBlockBefore(block, op, blockArgs); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); } @@ -797,6 +797,48 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) { return getOperand(0); } +void AtenToDtypeLayoutOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + // `to.dtype_layout` -> `to.device/to.dtype` if layout is none and pin memory + // is false + patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) { + // The pin_memory arg should be either constant `False` or `none`. + if (!op.getPinMemory().getType().isa()) { + bool pinMemory; + if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) + return failure(); + else if (pinMemory) + return failure(); + } + + // The layout arg should be either `none` or `0` i.e. strided. + if (!op.getLayout().getType().isa()) { + int64_t tensorLayout; + if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) + return failure(); + else if (tensorLayout != torch_upstream::Layout::Strided) + return failure(); + } + + if (op.getDevice().getType().isa()) { + // The device arg is `none`. Rewrite to to.dtype. + AtenToDtypeOp toDtype = rewriter.create( + op.getLoc(), op.getType(), op.getSelf(), op.getDtype(), + op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); + rewriter.replaceOp(op, toDtype->getResults()); + } else { + // The device arg is not `none`. Rewrite to to.device. + AtenToDeviceOp toDevice = rewriter.create( + op.getLoc(), op.getType(), op.getSelf(), op.getDevice(), + op.getDtype(), op.getNonBlocking(), op.getCopy(), + op.getMemoryFormat()); + rewriter.replaceOp(op, toDevice->getResults()); + } + + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenViewOp //===----------------------------------------------------------------------===// @@ -812,6 +854,15 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { return getOperand(0); } +//===----------------------------------------------------------------------===// +// PrimsViewOfOp +//===----------------------------------------------------------------------===// + +OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) { + // Always fold the op with its only input operand. + return getOperand(); +} + //===----------------------------------------------------------------------===// // AtenDimOp //===----------------------------------------------------------------------===// @@ -1137,6 +1188,8 @@ OpFoldResult AtenSizeIntOp::fold(FoldAdaptor adaptor) { return nullptr; ArrayRef sizes = type->getSizes(); dim = toPositiveDim(dim, sizes.size()); + if (!isValidDim(dim, sizes.size())) + return nullptr; return IntegerAttr::get(IntegerType::get(getContext(), 64), sizes[dim]); } @@ -1257,6 +1310,12 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, return nullptr; } +//===----------------------------------------------------------------------===// +// AtenDetachOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); } + //===----------------------------------------------------------------------===// // AtenNeIntOp //===----------------------------------------------------------------------===// @@ -2323,6 +2382,20 @@ OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// PrimDeviceOp +//===----------------------------------------------------------------------===// + +void PrimDeviceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](PrimDeviceOp op, PatternRewriter &rewriter) { + // Device information isn't relevant to torch-mlir, just replace it with + // "cpu". + rewriter.replaceOpWithNewOp(op, "cpu"); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenIntTensorOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 5cd48b3a1c56..712040cf7347 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -86,9 +86,25 @@ bool Torch::isValidSubtype(Type subtype, Type type) { return false; } + // `type` must not have more static shape information than `subtype`. + auto isSubsizes = [](BaseTensorType type, BaseTensorType subtype) -> bool { + auto typeSizes = type.getSizes(); + auto subtypeSizes = subtype.getSizes(); + if (typeSizes.size() != subtypeSizes.size()) { + return false; + } + for (auto t : llvm::zip(typeSizes, subtypeSizes)) { + if (std::get<0>(t) != Torch::kUnknownSize && + std::get<0>(t) != std::get<1>(t)) { + return false; + } + } + return true; + }; + if (typeTensorType.hasSizes() && (!subtypeTensorType.hasSizes() || - typeTensorType.getSizes() != subtypeTensorType.getSizes())) { + !isSubsizes(typeTensorType, subtypeTensorType))) { return false; } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index aed70815431f..35b82e614dd9 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -319,6 +319,90 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %8 : !torch.int\n" " }\n" +" func.func @__torch__.torch.jit._shape_functions.squeeze_dims(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" torch.prim.If.yield %arg0 : !torch.list\n" +" } else {\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %8 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.aten.append.t %3, %8 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %8 = torch.aten.__getitem__.t %3, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %10 = torch.aten.le.int %9, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" %12 = torch.aten.neg.int %11 : !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.lt.int %8, %12 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %20 = torch.aten.gt.int %8, %13 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" }\n" +" %16 = torch.aten.__not__ %15 : !torch.bool -> !torch.bool\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.lt.int %8, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.int) {\n" +" %20 = torch.aten.add.int %8, %11 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %20 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %8 : !torch.int\n" +" }\n" +" %19 = torch.aten._set_item.t %3, %arg2, %18 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %7, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %8 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %9 -> () {\n" +" %10 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list, !torch.int -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" %12 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.append.t %6, %12 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.append.t %6, %10 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @__torch__.torch.jit._shape_functions.unsqueeze(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" @@ -5668,6 +5752,98 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %1, %0, %0 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" " return %3 : !torch.tuple, list, list>\n" " }\n" +" func.func @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.float) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.lt.int %int0, %1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %15 = torch.aten.le.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.le.int %2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %15 = torch.aten.eq.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %17 = torch.aten.eq.int %15, %16 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %17 : !torch.bool\n" +" }\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %15 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" +" %16 = torch.aten.len.t %15 : !torch.list -> !torch.int\n" +" %17 = torch.aten.eq.int %16, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.bool) {\n" +" %19 = torch.aten.__getitem__.t %15, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.eq.int %19, %9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %18 : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %15 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %14 = torch.prim.If %13 -> (!torch.list) {\n" +" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %16 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %0 : !torch.list\n" +" }\n" +" return %14 : !torch.list\n" +" }\n" " func.func @__torch__.torch.jit._shape_functions.broadcast_three(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" @@ -5942,6 +6118,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.tanh\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6283,6 +6463,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" " return %2 : !torch.tuple, list>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.var_mean.dim\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.tuple, list> {\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.var_mean.dim\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.tuple {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.var_mean\"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.tuple, list> {\n" " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " %1 = torch.prim.ListConstruct : () -> !torch.list\n" @@ -6306,56 +6498,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %none = torch.constant.none\n" -" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.list) {\n" -" %2 = torch.prim.ListConstruct : () -> !torch.list\n" -" torch.prim.If.yield %2 : !torch.list\n" +" %str = torch.constant.str \"AssertionError: getting num_classes from tensor contents is not supported\"\n" +" %int-1 = torch.constant.int -1\n" +" %0 = torch.aten.ne.int %arg1, %int-1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" " } else {\n" -" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" -" %3 = func.call @__torch__._reduce_along_dim(%arg0, %2, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list\n" -" torch.prim.If.yield %3 : !torch.list\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" " }\n" -" return %1 : !torch.list\n" -" }\n" -" func.func @__torch__._reduce_along_dim(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" -" %true = torch.constant.bool true\n" -" %int9223372036854775807 = torch.constant.int 9223372036854775807\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg1, %0, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n" -" %2 = torch.prim.ListConstruct : () -> !torch.list\n" -" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %4 = torch.prim.ListConstruct %int9223372036854775807, %3 : (!torch.int, !torch.int) -> !torch.list\n" -" %5 = torch.prim.min.self_int %4 : !torch.list -> !torch.int\n" -" torch.prim.Loop %5, %true, init() {\n" -" ^bb0(%arg3: !torch.int):\n" -" %6 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" -" %7 = torch.aten.eq.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If %arg2 -> () {\n" -" %8 = torch.aten.append.t %2, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.If.yield\n" -" } else {\n" -" %8 = torch.aten.append.t %2, %6 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" +" %1 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list\n" +" %2 = torch.aten.add.t %arg0, %1 : !torch.list, !torch.list -> !torch.list\n" " return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.any.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list\n" -" return %0 : !torch.list\n" +" %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" return %1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" -" %0 = call @__torch__._reduce_along_dim(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.bool) -> !torch.list\n" -" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" return %1 : !torch.tuple, list>\n" +" %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.amax\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" " %none = torch.constant.none\n" @@ -6378,6 +6548,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.movedim.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list\n" +" %1 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" +" %2 = call @__torch__.torch.jit._shape_functions.movedim(%arg0, %0, %1) : (!torch.list, !torch.list, !torch.list) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.movedim.int\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.transpose.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6420,102 +6600,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.bmm\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" -" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n" -" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: bmm only supports 3D tensors\"\n" -" %int3 = torch.constant.int 3\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %7 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %8 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %9 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %11 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %13 : !torch.list\n" +" %0 = call @__torch__.torch.jit._shape_functions.bmm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.baddbmm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float) -> !torch.list {\n" -" %str = torch.constant.str \"AssertionError: mismatching contracting dimension\"\n" -" %str_0 = torch.constant.str \"AssertionError: mismatching batch dimension\"\n" -" %none = torch.constant.none\n" -" %str_1 = torch.constant.str \"AssertionError: baddbmm only supports 3D tensors\"\n" -" %int3 = torch.constant.int 3\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" -" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %6 = torch.aten.eq.int %4, %5 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %7 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %8 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.aten.eq.int %7, %8 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %9 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %10 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %11 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.prim.ListConstruct %10, %11, %12 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %13 : !torch.list\n" +" %0 = call @__torch__.torch.jit._shape_functions.bmm(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.embedding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.embedding(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.list, !torch.int, !torch.bool, !torch.bool) -> !torch.list\n" @@ -6872,6 +6962,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " return %arg2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.randint\"(%arg0: !torch.int, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" +" return %arg1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.randn\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -7020,6 +7113,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.squeeze(%arg0, %arg1) : (!torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.prims.squeeze\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.squeeze_dims(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.prims.view_of\"(%arg0: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.view_of\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.prim.NumToTensor.Scalar\"(%arg0: !torch.float) -> !torch.list {\n" " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" @@ -7074,23 +7178,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple, list> {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %str_0 = torch.constant.str \"k ({}) is too big for dimension {} of size {}\"\n" -" %0 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" -" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" %4 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.format(%str_0, %arg1, %arg2, %4) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" -" %6 = torch.aten.add.str %str, %5 : !torch.str, !torch.str -> !torch.str\n" -" torch.prim.RaiseException %6, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten._set_item.t %arg0, %arg2, %arg1 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %3 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" return %3 : !torch.tuple, list>\n" +" %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.topk\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" @@ -7116,6 +7211,36 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" +" %int10 = torch.constant.int 10\n" +" %int9 = torch.constant.int 9\n" +" %int5 = torch.constant.int 5\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.ListConstruct %int11, %int5, %int9, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.__contains__.int_list %3, %0#1 : !torch.list, !torch.int -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %7 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %8 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %8 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -7123,19 +7248,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.list, !torch.list, !torch.optional>) -> !torch.tuple, list, list>\n" " return %0 : !torch.tuple, list, list>\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.list) -> !torch.tuple, list, list> {\n" -" %none = torch.constant.none\n" -" %0 = torch.derefine %none : !torch.none to !torch.optional>\n" -" %1 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.list, !torch.list, !torch.optional>) -> !torch.tuple, list, list>\n" -" return %1 : !torch.tuple, list, list>\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list {\n" -" return %arg0 : !torch.list\n" +" %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" +" %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sort\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.narrow\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %int1 = torch.constant.int 1\n" " %0 = torch.aten.add.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.int\n" @@ -7240,98 +7370,8 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.tuple, list, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" -" %int-1 = torch.constant.int -1\n" -" %true = torch.constant.bool true\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %false = torch.constant.bool false\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %2 = torch.aten.lt.int %int0, %0 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %15 = torch.aten.le.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.le.int %1, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %5 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.bool) {\n" -" %15 = torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %7 = torch.prim.If %6 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %16 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %17 = torch.aten.eq.int %15, %16 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %17 : !torch.bool\n" -" }\n" -" torch.prim.If %7 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %8 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.prim.ListConstruct : () -> !torch.list\n" -" %10 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" -" %15 = torch.prim.unchecked_cast %arg2 : !torch.optional> -> !torch.list\n" -" %16 = torch.aten.len.t %15 : !torch.list -> !torch.int\n" -" %17 = torch.aten.eq.int %16, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" %18 = torch.prim.If %17 -> (!torch.bool) {\n" -" %19 = torch.aten.__getitem__.t %15, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %20 = torch.aten.eq.int %19, %8 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %20 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If.yield %18 : !torch.bool\n" -" }\n" -" torch.prim.If %11 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %12 = torch.aten.eq.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" %15 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" %14 = torch.prim.If %13 -> (!torch.tuple, list>) {\n" -" %15 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %16 = torch.prim.ListConstruct %15 : (!torch.int) -> !torch.list\n" -" %17 = torch.prim.TupleConstruct %16, %9 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" torch.prim.If.yield %17 : !torch.tuple, list>\n" -" } else {\n" -" %15 = torch.prim.TupleConstruct %9, %9 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" torch.prim.If.yield %15 : !torch.tuple, list>\n" -" }\n" -" return %14 : !torch.tuple, list>\n" +" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.list, !torch.optional>, !torch.int) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.nll_loss_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional>, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" @@ -7349,38 +7389,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cross_entropy_loss\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.optional>, !torch.int, !torch.int, !torch.float) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple, list, list> {\n" -" %true = torch.constant.bool true\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int0 = torch.constant.int 0\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list\n" -" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %3 = torch.aten.sub.int %1, %2 : !torch.int, !torch.int -> !torch.int\n" -" %4 = torch.aten.ge.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.Loop %3, %true, init() {\n" -" ^bb0(%arg5: !torch.int):\n" -" %8 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.aten.append.t %0, %8 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %6 = torch.aten.__range_length %3, %5, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" torch.prim.Loop %6, %true, init() {\n" -" ^bb0(%arg5: !torch.int):\n" -" %8 = torch.aten.append.t %0, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %7 = torch.prim.TupleConstruct %arg0, %0, %0 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" -" return %7 : !torch.tuple, list, list>\n" +" %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list, list>\n" +" return %0 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_batch_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float) -> !torch.tuple, list, list> {\n" " %none = torch.constant.none\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c218395cd1f2..f5de87671fb0 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -171,35 +171,14 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter, return sub; } -// Helper function to unsqueeze the input tensor at given dim. -// Return the unsqueezed tensor or failure. -static FailureOr unsqueezeTensor(PatternRewriter &rewriter, - Operation *op, Value input, Value dim) { - BaseTensorType inputType = input.getType().cast(); - if (!inputType.hasSizes()) { - return rewriter.notifyMatchFailure(op, "input tensor must have size"); - } - - SmallVector unsqueezedShape; - ArrayRef inputShape = inputType.getSizes(); - // `input` has a reduced rank. Hence add 1. - int64_t unsqueezedRank = inputShape.size() + 1; - int64_t dimInt = 0; - if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { - dimInt = toPositiveDim(dimInt, unsqueezedRank); - if (!isValidDim(dimInt, unsqueezedRank)) { - return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); - } - unsqueezedShape.append(inputShape.begin(), inputShape.end()); - unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1); - } else { - unsqueezedShape.resize(unsqueezedRank, kUnknownSize); - } - Type unsqueezedType = inputType.getWithSizesAndDtype( - unsqueezedShape, inputType.getOptionalDtype()); - Value unsqueezed = rewriter.create( - op->getLoc(), unsqueezedType, input, dim); - return unsqueezed; +static SmallVector computeDimsOrderForMoveDim(int64_t srcDimInt, + int64_t dstDimInt, + unsigned inputRank) { + llvm::iota_range dimsOrderIR(0, inputRank, /*inclusive=*/false); + SmallVector dimsOrder(dimsOrderIR.begin(), dimsOrderIR.end()); + dimsOrder.erase(dimsOrder.begin() + srcDimInt); + dimsOrder.insert(dimsOrder.begin() + dstDimInt, srcDimInt); + return dimsOrder; } namespace { @@ -234,12 +213,16 @@ class DecomposeAtenAmaxOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "Expected a constant boolean value for keepDim"); - Value input = op.getSelf(); + Value input = op.getSelf(); // For every dimension included in `dim` of the op, iterated over in // reverse order, we create a call to aten.max.dim. std::sort(dims.begin(), dims.end()); std::reverse(dims.begin(), dims.end()); for (int64_t dimInt : dims) { + int64_t inputRank = input.getType().cast().getSizes().size(); + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(dimInt)); // The input to the next invocation of aten.max.dim is the output of the @@ -1451,25 +1434,6 @@ class DecomposeAtenMaskedFillScalarOp return success(); } }; - -} // namespace -// Decompose aten.convolution_overrideable to aten.convolution op. -namespace { -class DecomposeAtenConvolutionOverrideableOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenConvolutionOverrideableOp op, - PatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), - op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), - op.getOutputPadding(), op.getGroups()); - - return success(); - } -}; } // namespace // Decompose aten._convolution-like to aten.convolution @@ -1533,27 +1497,72 @@ class DecomposeAtenConvTranspose2dOp }; } // namespace -// Decompose aten.convolution_backward_overrideable to aten.convolution_backward -// op. -namespace { -class DecomposeAtenConvolutionBackwardOverrideableOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenConvolutionBackwardOverrideableOp op, - PatternRewriter &rewriter) const override { - - Value none = rewriter.create(op->getLoc()); - rewriter.replaceOpWithNewOp( - op, op.getResultTypes(), op.getGradOutput(), op.getInput(), op.getWeight(), - none, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), - op.getOutputPadding(), op.getGroups(), op.getOutputMask()); - - return success(); - } -}; -} // namespace +static LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, + int64_t dimB, Type &transposedType) { + if (!inType.hasSizes()) + return failure(); + SmallVector shape(inType.getSizes()); + int64_t tmp = shape[0]; + shape[0] = shape[1]; + shape[1] = tmp; + transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape), + inType.getOptionalDtype()); + return success(); +} +// The convolution backward op is decomposed as follows: +// inputH, inputW = input.shape[2:] +// output_padding_ = [ +// inputH +// - 1 +// + 2 * padding_[0] +// - dilation_[0] * (weight.shape[2] - 1) +// - (grad_output.shape[2] - 1) * stride_[0], +// inputW +// - 1 +// + 2 * padding_[1] +// - dilation_[1] * (weight.shape[3] - 1) +// - (grad_output.shape[3] - 1) * stride_[1], +// ] +// +// decomp_grad_input = torch.nn.functional.conv_transpose2d( +// grad_output, +// weight, +// None, +// stride_, +// padding_, +// output_padding_, +// groups_, +// dilation_, +// ) +// +// input_transposed = torch.ops.aten.transpose(input, 0, 1) +// grad_output_transposed = grad_output.view( +// grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] +// ) +// decomp_grad_weight = torch.ops.aten.convolution( +// input_transposed, +// grad_output_transposed, +// bias=None, +// stride=dilation_, +// padding=padding_, +// dilation=stride_, +// transposed=False, +// output_padding=[0, 0], +// groups=input.shape[0], +// ) +// decomp_grad_weight = torch.narrow(decomp_grad_weight, 2, 0, weight.shape[2]) +// decomp_grad_weight = torch.narrow(decomp_grad_weight, 3, 0, weight.shape[3]) +// decomp_grad_weight = decomp_grad_weight.view( +// input_transposed.shape[0], +// input_transposed.shape[1], +// grad_output.shape[1], +// *decomp_grad_weight.shape[2:] +// ) +// decomp_grad_weight = decomp_grad_weight.movedim(0, 2) +// decomp_grad_weight = decomp_grad_weight.sum(dim=0) +// +// decomp_grad_bias = torch.sum(grad_output, dim=[0, 2, 3]) namespace { class DecomposeAtenConvolutionBackwardOp : public OpRewritePattern { @@ -1564,6 +1573,8 @@ class DecomposeAtenConvolutionBackwardOp Location loc = op.getLoc(); MLIRContext *context = op.getContext(); + Value input = op.getInput(); + Value weight = op.getWeight(); Value gradOutput = op.getGradOutput(); std::optional maybeGradRank = getTensorRank(gradOutput); if (!maybeGradRank) { @@ -1571,6 +1582,10 @@ class DecomposeAtenConvolutionBackwardOp "expected grad output to have a rank"); } unsigned gradRank = *maybeGradRank; + if (gradRank != 4) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D convolutions supported."); + Value cstZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create( @@ -1581,113 +1596,282 @@ class DecomposeAtenConvolutionBackwardOp Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); - Value input = op.getInput(); - Value weight = op.getWeight(); + SmallVector padding, dilation, stride; + SmallVector paddingInt, dilationInt, strideInt, + outputPaddingInt; - if (gradRank != 4) + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInt))) return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D convolutions supported."); + op, "padding must be a list of constant ints"); - SmallVector padding; - if (!getListConstructElements(op.getPadding(), padding)) - return rewriter.notifyMatchFailure(op, "padding must be a list."); + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInt))) + return rewriter.notifyMatchFailure( + op, "stride must be a list of constant ints"); - SmallVector strides; - if (!getListConstructElements(op.getStride(), strides)) - return rewriter.notifyMatchFailure(op, "stride must be a list."); - for (Value stride : strides) { - Value cmp = rewriter.create(loc, stride, cstOne); - rewriter.create( - loc, cmp, "unimplemented: only strides of 1 supported."); - } + if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInt))) + return rewriter.notifyMatchFailure( + op, "dilation must be a list of constant ints"); + if (!llvm::all_of(dilationInt, + [](int64_t dilationVal) { return dilationVal == 1; })) + return rewriter.notifyMatchFailure( + op, "unimplemented: only dilations of 1 supported."); - SmallVector dilations; - if (!getListConstructElements(op.getDilation(), dilations)) - return rewriter.notifyMatchFailure(op, "dilation must be a list."); - for (Value dilation : dilations) { - Value cmp = rewriter.create(loc, dilation, cstOne); - rewriter.create( - loc, cmp, "unimplemented: only dilations of 1 supported."); - } + if (!matchPattern(op.getOutputPadding(), + m_TorchListOfConstantInts(outputPaddingInt))) + return rewriter.notifyMatchFailure( + op, "output padding must be a list of constant ints"); + if (!llvm::all_of(outputPaddingInt, + [](int64_t outPad) { return outPad == 0; })) + return rewriter.notifyMatchFailure( + op, "unimplemented: only output padding of 0 supported."); SmallVector outMask; if (!matchPattern(op.getOutputMask(), m_TorchListOfConstantBools(outMask))) return rewriter.notifyMatchFailure( op, "only constant bool output_mask is supported."); - // Support for `False` values for output mask unimplemented. - if (!llvm::all_of(outMask, [](bool mask) { return mask; })) - return rewriter.notifyMatchFailure( - op, "unimplemented: only true values for output_mask supported."); + for (unsigned i = 0; i < outMask.size(); i++) { + if (outMask[i] == false) { + Value result = op->getResults()[i]; + if (!result.getUsers().empty()) + return rewriter.notifyMatchFailure( + op, "unimplemented: false value supported for output_mask only " + "when the result tensor corresponding to that has no users."); + } + } bool transposed; if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) return rewriter.notifyMatchFailure( - op, "only constant transposed is supported."); + op, "transposed arg should be a constant bool."); if (transposed) return rewriter.notifyMatchFailure( op, "unimplemented: transposed convolutions are not supported."); - // Rotate weight. - SmallVector axes; - for (unsigned i = 2; i < gradRank; i++) { - axes.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); - } - Value axesList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), axes); - weight = rewriter.create(loc, weight.getType(), weight, - axesList); - // Calculate padding for first convolution. - SmallVector gradInputPaddingValues; + getListConstructElements(op.getPadding(), padding); + getListConstructElements(op.getStride(), stride); + getListConstructElements(op.getDilation(), dilation); + + // Computing Grad Input. + // Calculate output padding for first convolution. + // output_padding_ = [ + // inputH - 1 + (2 * padding_[0]) - (dilation_[0] * (weight.size()[2] + // - 1)) - ((grad_out.size()[2] - 1) * stride_[0]), inputW - 1 + (2 * + // padding_[1]) - (dilation_[1] * (weight.size()[3] - 1)) - + // ((grad_out.size()[3] - 1) * stride_[1]), + // ] + SmallVector outputPaddingValues; for (unsigned i = 2; i < gradRank; i++) { Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); - Value outDim = rewriter.create(loc, input, dim); - - // Calculate 1 + (weightDim // 2) * 2, which fixes issues with - // even-sized weight. - Value weightDim = rewriter.create(loc, weight, dim); - weightDim = - rewriter.create(loc, weightDim, cstTwo); - weightDim = rewriter.create(loc, weightDim, cstTwo); - weightDim = rewriter.create(loc, weightDim, cstOne); + Value inputVecDim = + rewriter.create(loc, input, dim); Value gradOutDim = rewriter.create(loc, gradOutput, dim); + Value weightDim = rewriter.create(loc, weight, dim); + Value inputVecDimMinusOne = + rewriter.create(loc, inputVecDim, cstOne); + Value gradOutDimMinusOne = + rewriter.create(loc, gradOutDim, cstOne); + Value weightDimMinusOne = + rewriter.create(loc, weightDim, cstOne); + Value twoTimesPadding = + rewriter.create(loc, padding[i - 2], cstTwo); + Value tmpA = rewriter.create(loc, weightDimMinusOne, + dilation[i - 2]); + Value tmpB = rewriter.create(loc, gradOutDimMinusOne, + stride[i - 2]); + Value outputPaddingVal = rewriter.create( + loc, inputVecDimMinusOne, twoTimesPadding); + outputPaddingVal = + rewriter.create(loc, outputPaddingVal, tmpA); + outputPaddingVal = + rewriter.create(loc, outputPaddingVal, tmpB); + outputPaddingValues.push_back(outputPaddingVal); + } + Value outputPaddingForGradInput = + rewriter.create( + loc, ListType::get(IntType::get(context)), outputPaddingValues); + + Value gradInput = rewriter.create( + loc, op.getResultTypes()[0], gradOutput, weight, cstNone, + op.getStride(), op.getPadding(), outputPaddingForGradInput, + op.getGroups(), op.getDilation()); + + Type transposedType; + if (failed(getTransposedType(input.getType().cast(), 0, 1, + transposedType))) + return failure(); + Value inputTransposed = rewriter.create( + loc, transposedType, input, cstZero, cstOne); + + // For the cases where the stride is non-unit, we compute the `GradWeight` + // through this implementation. + Value gradWeight; + if (!llvm::all_of(strideInt, [](int64_t stride) { return stride == 1; })) { + // Computing Grad Weight. + SmallVector gradOutputSize; + for (unsigned i = 0; i < gradRank; i++) { + gradOutputSize.push_back(rewriter.create( + loc, gradOutput, + rewriter.create( + loc, rewriter.getI64IntegerAttr(i)))); + } - // Calculate (((outDim - 1) * stride) + weightDim - gradOutDim) // 2, - // the padding value for this dimension. Derived from the formula at - // https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html - Value padVal = rewriter.create(loc, outDim, cstOne); - padVal = - rewriter.create(loc, padVal, strides[i - 2]); - padVal = rewriter.create(loc, padVal, weightDim); - padVal = rewriter.create(loc, padVal, gradOutDim); - padVal = rewriter.create(loc, padVal, cstTwo); - - gradInputPaddingValues.push_back(padVal); - } - Value gradInputPadding = rewriter.create( - loc, ListType::get(IntType::get(context)), gradInputPaddingValues); - Value weightTransposed = rewriter.create( - loc, weight.getType(), weight, cstZero, cstOne); - // Convolve grad_output with weight. - Value gradInput = rewriter.create( - loc, op.getResultTypes()[0], gradOutput, weightTransposed, cstNone, - op.getStride(), gradInputPadding, op.getDilation(), op.getTransposed(), - op.getOutputPadding(), op.getGroups()); + Value gradOutputViewDimZero = rewriter.create( + loc, gradOutputSize[0], gradOutputSize[1]); + Value gradOutputViewShapeList = + rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], + gradOutputSize[3]}); + + BaseTensorType gradOutputTy = gradOutput.getType().cast(); + if (!gradOutputTy.hasSizes()) + return failure(); + SmallVector gradOutputSizesInt(gradOutputTy.getSizes()); + SmallVector gradOutputViewSizesInt(gradOutputSizesInt); + if (gradOutputViewSizesInt[0] != kUnknownSize && + gradOutputViewSizesInt[1] != kUnknownSize) + gradOutputViewSizesInt[0] *= gradOutputViewSizesInt[1]; + else + gradOutputViewSizesInt[0] = kUnknownSize; + gradOutputViewSizesInt[1] = 1; + BaseTensorType gradOutputTypeForView = + gradOutputTy + .getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt), + gradOutputTy.getOptionalDtype()) + .cast(); + Value gradOutputView = rewriter.create( + loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList); + + BaseTensorType inputTransposedTy = + inputTransposed.getType().cast(); + if (!inputTransposedTy.hasSizes()) + return failure(); + SmallVector inputTransposedSizesInt( + inputTransposedTy.getSizes()); + SmallVector gradWeightSizesInt{inputTransposedSizesInt[0], + gradOutputViewSizesInt[0]}; + for (unsigned i = 2; i < gradRank; i++) { + if (inputTransposedSizesInt[i] != kUnknownSize && + gradOutputViewSizesInt[i] != kUnknownSize) { + int64_t kernelSizeInt = + strideInt[i - 2] * (gradOutputViewSizesInt[i] - 1) + 1; + gradWeightSizesInt.push_back( + ((inputTransposedSizesInt[i] + (paddingInt[i - 2] * 2) - + kernelSizeInt) / + dilationInt[i - 2]) + + 1); + } else { + gradWeightSizesInt.push_back(kUnknownSize); + } + } - Value gradOutputTransposed = rewriter.create( - loc, gradOutput.getType(), gradOutput, cstZero, cstOne); - Value inputTransposed = rewriter.create( - loc, input.getType(), input, cstZero, cstOne); - // Convolve input with grad_output. - Value gradWeight = rewriter.create( - loc, op.getResultTypes()[1], inputTransposed, gradOutputTransposed, - cstNone, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), - op.getOutputPadding(), op.getGroups()); - gradWeight = rewriter.create( - loc, gradWeight.getType(), gradWeight, cstZero, cstOne); + BaseTensorType gradWeightTy = + inputTransposedTy + .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), + inputTransposedTy.getOptionalDtype()) + .cast(); + Value numGroup = rewriter.create(loc, input, cstZero); + gradWeight = rewriter.create( + loc, gradWeightTy, inputTransposed, gradOutputView, cstNone, + /*stride=*/op.getDilation(), op.getPadding(), + /*dilation=*/op.getStride(), op.getTransposed(), + op.getOutputPadding(), numGroup); + + BaseTensorType weightTy = weight.getType().cast(); + if (!weightTy.hasSizes()) + return failure(); + SmallVector weightSizes(weightTy.getSizes()); + for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { + gradWeightSizesInt[i + 2] = weightSizes[i + 2]; + BaseTensorType gradWeightNarrowTy = + gradWeightTy + .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), + gradWeightTy.getOptionalDtype()) + .cast(); + + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(i + 2)); + Value length = rewriter.create(loc, weight, dim); + gradWeight = rewriter.create( + loc, gradWeightNarrowTy, gradWeight, dim, /*start=*/cstZero, + length); + } + + SmallVector gradWeightViewShapeInt{ + inputTransposedSizesInt[0], inputTransposedSizesInt[1]}; + gradWeightViewShapeInt.push_back(gradOutputSizesInt[1]); + gradWeightViewShapeInt.insert( + gradWeightViewShapeInt.end(), + {gradWeightSizesInt[2], gradWeightSizesInt[3]}); + + SmallVector gradWeightViewShapeValue; + for (unsigned i = 0; i < gradWeightViewShapeInt.size(); i++) { + gradWeightViewShapeValue.push_back( + rewriter.create( + loc, rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i]))); + } + + Value gradWeightViewShapeList = + rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + gradWeightViewShapeValue); + + BaseTensorType gradWeightTypeForView = + gradWeightTy + .getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt), + gradWeightTy.getOptionalDtype()) + .cast(); + gradWeight = rewriter.create( + loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList); + + gradWeightTy = gradWeight.getType().cast(); + SmallVector gradWeightDimsOrder = + computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size()); + SmallVector gradWeightMoveDimShape; + for (unsigned i = 0; i < gradWeightDimsOrder.size(); i++) { + gradWeightMoveDimShape.push_back( + gradWeightViewShapeInt[gradWeightDimsOrder[i]]); + } + BaseTensorType gradWeightTypeForMoveDim = + gradWeightTy + .getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape), + gradWeightTy.getOptionalDtype()) + .cast(); + + gradWeight = rewriter.create( + loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero, + /*destination=*/cstTwo); + + Value gradIntList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + llvm::ArrayRef{cstZero}); + gradWeight = rewriter.create( + loc, op.getResultTypes()[1], /*self=*/gradWeight, /*dim=*/gradIntList, + /*keepdim=*/cstFalse, + /*dtype=*/cstNone); + } else { + if (failed(getTransposedType(gradOutput.getType().cast(), + 0, 1, transposedType))) + return failure(); + Value gradOutputTransposed = rewriter.create( + loc, transposedType, gradOutput, cstZero, cstOne); + // Convolve input with grad_output. + if (failed( + getTransposedType(op.getResultTypes()[1].cast(), + 0, 1, transposedType))) + return failure(); + gradWeight = rewriter.create( + loc, transposedType, inputTransposed, gradOutputTransposed, cstNone, + op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), + op.getOutputPadding(), op.getGroups()); + gradWeight = rewriter.create( + loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne); + } + + // Computing Grad Bias. SmallVector dimIntList{cstZero}; for (unsigned i = 2; i < gradRank; i++) dimIntList.push_back(rewriter.create( @@ -1695,6 +1879,7 @@ class DecomposeAtenConvolutionBackwardOp Value gradIntList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dimIntList); + // Sum grad_output along dim 1. Value gradBias = rewriter.create( loc, op.getResultTypes()[2], gradOutput, gradIntList, cstFalse, @@ -3633,6 +3818,28 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenRandintOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRandintOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Type resultType = op.getType(); + + Value low = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + rewriter.replaceOpWithNewOp( + op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(), op.getLayout(), + op.getDevice(), op.getPinMemory()); + + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.varMean.correction` op into `aten.var.correction` and // `aten.mean.dim` op. @@ -3880,6 +4087,282 @@ class DecomposeAtenNewEmptyStridedOp }; } // namespace +namespace { +class DecomposePrimsSqueezeOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimsSqueezeOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getA(); + SmallVector dimensions; + if (!matchPattern(op.getDimensions(), + m_TorchListOfConstantInts(dimensions))) + return rewriter.notifyMatchFailure( + op, "all dimensions must be constant ints"); + + std::sort(dimensions.begin(), dimensions.end()); + std::reverse(dimensions.begin(), dimensions.end()); + + if (dimensions.size() == 0) { + rewriter.replaceOp(op, input); + return success(); + } + Value result = input; + for (unsigned i = 0; i < dimensions.size(); i++) { + auto squeezeTensorInfo = + squeezeTensor(rewriter, op, loc, dimensions[i], result); + if (failed(squeezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + result = *squeezeTensorInfo; + } + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenMovedimIntOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenMovedimIntOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getSelf(); + std::optional maybeInputRank = getTensorRank(input); + if (!maybeInputRank) { + return rewriter.notifyMatchFailure( + op, "expected input tensor to have a rank"); + } + unsigned inputRank = *maybeInputRank; + if (inputRank <= 1) { + rewriter.replaceOp(op, input); + return success(); + } + + int64_t srcDimInt, dstDimInt; + if (matchPattern(op.getSource(), m_TorchConstantInt(&srcDimInt))) { + srcDimInt = toPositiveDim(srcDimInt, inputRank); + if (!isValidDim(srcDimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "source is not a valid dim"); + } else { + return rewriter.notifyMatchFailure(op, "source is not a constant int"); + } + if (matchPattern(op.getDestination(), m_TorchConstantInt(&dstDimInt))) { + dstDimInt = toPositiveDim(dstDimInt, inputRank); + if (!isValidDim(dstDimInt, inputRank)) + return rewriter.notifyMatchFailure(op, + "destination is not a valid dim"); + } else { + return rewriter.notifyMatchFailure(op, + "destination is not a constant int"); + } + + SmallVector dimsOrder = + computeDimsOrderForMoveDim(srcDimInt, dstDimInt, inputRank); + SmallVector cstDimsOrder; + for (int64_t dim : dimsOrder) + cstDimsOrder.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(dim))); + Value permuteDimsOrder = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + cstDimsOrder); + rewriter.replaceOpWithNewOp(op, op.getType(), input, + permuteDimsOrder); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenCrossEntropyLossOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenCrossEntropyLossOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value target = op.getTarget(); + std::optional maybeRank = getTensorRank(self); + if (!maybeRank) + return rewriter.notifyMatchFailure( + op, "Unimplemented: unranked input tensor"); + unsigned selfRank = maybeRank.value(); + maybeRank = getTensorRank(target); + if (!maybeRank) + return rewriter.notifyMatchFailure( + op, "Unimplemented: unranked target tensor"); + unsigned targetRank = maybeRank.value(); + + // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d + // of the form [minibatch] the cross entropy loss decomposes to the + // combination of softmax and nll loss as follows: + // cross_entropy_loss = NLLLoss(LogSoftmax(input, dim=1), target) + // Currently, we only support the above-mentioned case. + if (selfRank != 2 || targetRank != 1) { + return rewriter.notifyMatchFailure( + op, + "unimplemented: only support cases with 2-d input and 1-d target"); + } + + // TODO: Add support for label_smoothing value other than 0.0 (default + // value). + double labelSmoothing; + if (!matchPattern(op.getLabelSmoothing(), + m_TorchConstantFloat(&labelSmoothing))) { + return rewriter.notifyMatchFailure( + op, "Only support constant float label_smoothing value"); + } else if (labelSmoothing != 0.0) { + return rewriter.notifyMatchFailure(op, + "unimplemented: only support default " + "value of 0.0 for label_smoothing"); + } + + Value noneVal = rewriter.create(loc); + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value logSoftmax = rewriter.create( + loc, self.getType(), self, dim, /*dtype=*/noneVal); + Value nllLoss = + rewriter + .create( + loc, op.getType(), target.getType(), logSoftmax, target, + op.getWeight(), op.getReduction(), op.getIgnoreIndex()) + ->getResult(0); + rewriter.replaceOp(op, nllLoss); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenOneHotOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenOneHotOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + + Value input = op.getSelf(); + auto inputType = input.getType().cast(); + if (!inputType.hasSizes()) + return rewriter.notifyMatchFailure( + op, "input tensor should have known sizes."); + int64_t inputRank = inputType.getSizes().size(); + int64_t numClasses; + if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses))) + return rewriter.notifyMatchFailure( + op, "unimplemented: num_classes must be constant"); + Value none = rewriter.create(loc); + Value falseValue = rewriter.create(loc, false); + + // arange tensor + auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); + auto arangeType = + ValueTensorType::get(context, llvm::ArrayRef(numClasses), si64Type); + Value arangeTensor = rewriter.create( + loc, arangeType, op.getNumClasses(), /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + + // unsqueeze input + llvm::SmallVector unsqueezeShape(inputType.getSizes()); + unsqueezeShape.push_back(1); + auto unsqueezeType = + ValueTensorType::get(context, unsqueezeShape, si64Type); + Value unsqueezeTensor = rewriter.create( + loc, unsqueezeType, input, + rewriter.create(loc, + rewriter.getI64IntegerAttr(inputRank))); + + // compare + auto eqType = ValueTensorType::get( + context, op.getType().cast().getSizes(), + IntegerType::get(context, 1)); + Value eqTensor = rewriter.create( + loc, eqType, unsqueezeTensor, arangeTensor); + + // convert to si64 + Value si64TypeValue = + Torch::getDtypeIntValueForType(rewriter, loc, si64Type); + Value result = rewriter.create( + loc, op.getType(), eqTensor, si64TypeValue, /*non_blocking=*/falseValue, + /*copy=*/falseValue, /*memory_format=*/none); + rewriter.replaceOp(op, result); + return success(); + } +}; + +} // namespace + +namespace { +// Decompose `aten.var_mean.dim` op into `aten.var.dim` and +// `aten.mean.dim` op. +class DecomposeAtenVarMeanDimOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenVarMeanDimOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value noneVal = rewriter.create(loc); + Value var = rewriter.create(loc, op.getType(0), op.getSelf(), + op.getDim(), op.getUnbiased(), + op.getKeepdim()); + Value mean = rewriter.create( + loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(), + /*dtype=*/noneVal); + rewriter.replaceOp(op, {var, mean}); + return success(); + } +}; +} // namespace + +namespace { +// Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` op. +class DecomposeAtenTopkOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTopkOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + + bool sorted; + if (!matchPattern(op.getSorted(), m_TorchConstantBool(&sorted))) + return rewriter.notifyMatchFailure( + op, "Expected a constant boolean value for sorted"); + if (!sorted) + return rewriter.notifyMatchFailure( + op, "unimplemented: sorted value arg must be set to True"); + + Value self = op.getSelf(); + Value dim = op.getDim(); + auto selfType = self.getType().cast(); + auto sortIndicesType = selfType.getWithSizesAndDtype( + selfType.getOptionalSizes(), + IntegerType::get(context, 64, IntegerType::Signed)); + auto sortOpResult = rewriter.create( + loc, self.getType(), sortIndicesType, self, dim, + /*descending=*/op.getLargest()); + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value step = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value resultValue = rewriter.create( + loc, op->getResultTypes()[0], sortOpResult->getResult(0), dim, start, + /*end=*/op.getK(), step); + Value resultIndices = rewriter.create( + loc, op->getResultTypes()[1], sortOpResult->getResult(1), dim, start, + /*end=*/op.getK(), step); + rewriter.replaceOp(op, {resultValue, resultIndices}); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -3898,7 +4381,7 @@ class DecomposeComplexOpsPass // on `Operation *` are not allowed, since there is no way of telling if // that pattern will match on an op in the `legalOpsSet` or not. assert(opName && "All decomposition patterns must target a single op"); - if (!legalOpsSet.contains(opName->getStringRef())) + if (!legalOpsSet.contains(opName->getStringRef().ltrim(kTorchOpPrefix))) patterns.add(context); } @@ -3934,8 +4417,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal< - DecomposeAtenConvolutionBackwardOverrideableOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( @@ -3957,8 +4438,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal( - patterns); addPatternIfTargetOpIsIllegal< DecomposeAten_ConvolutionLikeOp>(patterns); addPatternIfTargetOpIsIllegal< @@ -4028,6 +4507,7 @@ class DecomposeComplexOpsPass patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -4041,6 +4521,12 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp b/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp index 61b79d6601cb..c3236c0324d1 100644 --- a/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/DropAbstractInterpCalculations.cpp @@ -29,7 +29,7 @@ class DropCalculateOp : public OpConversionPattern { Block *block = &op.getBody().front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); - rewriter.mergeBlockBefore(block, op); + rewriter.inlineBlockBefore(block, op); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); return success(); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 1f21a36568ef..ac077ca2f831 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -18,6 +18,7 @@ #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/Support/Debug.h" +#include "llvm/ADT/StringSet.h" #define DEBUG_TYPE "torch-lower-to-backend-contract" @@ -31,7 +32,7 @@ using namespace mlir::torch::Torch; static void markDecomposedOpsAsIllegal(MLIRContext *context, ConversionTarget &target, - ArrayRef backendLegalOps); + llvm::StringSet<> backendLegalOps); static LogicalResult checkType(Operation *op, Type type, bool actuallyEmitDiagnostics) { @@ -197,6 +198,24 @@ static bool satisfiesBackendContract(ModuleOp module, if (walkResult0.wasInterrupted()) return false; + // Check for unimplemented operators first to give more direct diagnostics. + walkResult0 = module.walk([&](Torch::OperatorOp op) { + if (llvm::all_of(op.getResults(), [&op](auto res) { + return succeeded( + checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false)); + })) { + return WalkResult::advance(); + } + + if (actuallyEmitDiagnostics) { + op->emitError("unsupported by backend contract: Unimplemented operator '" + + op.getName() + "'"); + } + return WalkResult::interrupt(); + }); + if (walkResult0.wasInterrupted()) + return false; + // Check all the types of all Value's in the program and the legality of all // the ops. // @@ -228,11 +247,11 @@ static bool satisfiesBackendContract(ModuleOp module, // Explicitly set ops and dialects allowed and not allowed in backend contract. static ConversionTarget getBackendContractTarget(MLIRContext *context, bool decompose, - ArrayRef backendLegalOps) { + llvm::StringSet<> backendLegalOpsSet) { ConversionTarget target(*context); target.addLegalDialect(); if (decompose) - markDecomposedOpsAsIllegal(context, target, backendLegalOps); + markDecomposedOpsAsIllegal(context, target, backendLegalOpsSet); return target; } @@ -242,21 +261,27 @@ class LowerToBackendContractPass public: LowerToBackendContractPass() = default; LowerToBackendContractPass(int maxIterations, bool decompose, - ArrayRef backendLegalOps) { + ArrayRef backendLegalOps, + StringRef extraLibrary) { this->maxIterations = maxIterations; this->decompose = decompose; this->backendLegalOps = backendLegalOps; + this->extraLibrary = extraLibrary.str(); } void runOnOperation() override { ModuleOp module = getOperation(); MLIRContext *context = &getContext(); + + backendLegalOpsSet.clear(); + backendLegalOpsSet.insert(backendLegalOps.begin(), backendLegalOps.end()); ConversionTarget target = - getBackendContractTarget(context, decompose, backendLegalOps); + getBackendContractTarget(context, decompose, backendLegalOpsSet); OpPassManager pm(module.getOperationName()); TorchLoweringPipelineOptions options; options.decompose = decompose; options.backendLegalOps = backendLegalOps; + options.extraLibrary = extraLibrary; createTorchSimplificationPipeline(pm, options); int i = 0; @@ -283,6 +308,8 @@ class LowerToBackendContractPass << " iterations of the simplification pipeline\n"; }); } +private: + llvm::StringSet<> backendLegalOpsSet; }; class VerifyBackendContractNoDecompositionsPass @@ -294,7 +321,7 @@ class VerifyBackendContractNoDecompositionsPass MLIRContext *context = &getContext(); ConversionTarget target = getBackendContractTarget(context, /*decompose*/false, - /*backendLegalOps*/{}); + /*backendLegalOpsSet*/{}); if (!satisfiesBackendContract(getOperation(), target, /*actuallyEmitDiagnostics=*/true)) { @@ -306,9 +333,10 @@ class VerifyBackendContractNoDecompositionsPass std::unique_ptr> mlir::torch::Torch::createLowerToBackendContractPass( - int maxIterations, bool decompose, ArrayRef backendLegalOps) { - return std::make_unique(maxIterations, decompose, - backendLegalOps); + int maxIterations, bool decompose, ArrayRef backendLegalOps, + StringRef extraLibrary) { + return std::make_unique( + maxIterations, decompose, backendLegalOps, extraLibrary); } std::unique_ptr> @@ -319,9 +347,9 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() { // The backend contract guarantees that ops with decompositions available will // be decomposed. The only way to have an op reach the backend contract without // getting decomposed is by having the user explicitly specify that op in the -// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore, +// `backendLegalOpsSet` argument to the `LowerToBackendContractPass`. Therefore, // here we mark as illegal all ops with decompositions except for those in -// `backendLegalOps`. +// `backendLegalOpsSet`. // // The legality check takes place here instead of in the `DecomposeComplexOps` // pass for two reasons: @@ -334,7 +362,7 @@ mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() { // decompositions explicit in this file static void markDecomposedOpsAsIllegal(MLIRContext *context, ConversionTarget &target, - ArrayRef backendLegalOps) { + llvm::StringSet<> backendLegalOpsSet) { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -351,7 +379,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -377,7 +404,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -435,6 +461,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -445,7 +472,19 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - for (std::string opName : backendLegalOps) { - target.addLegalOp(OperationName(opName, context)); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + for (auto &opName : backendLegalOpsSet) { + target.addLegalOp( + OperationName(kTorchOpPrefix + opName.first().str(), context)); } + target.addDynamicallyLegalOp( + [backendLegalOpsSet](OperatorOp opOp) { + auto opName = opOp->getAttr("name").cast().getValue(); + return backendLegalOpsSet.contains(opName); + }); } diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index 934ff7c25281..5ed5d53bd06d 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -25,7 +25,7 @@ void mlir::torch::registerTorchPasses() { "torch-simplification-pipeline", "Pipeline simplifying computations in the program.", mlir::torch::Torch::createTorchSimplificationPipeline); - mlir::PassPipelineRegistration<>( + mlir::PassPipelineRegistration( "torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.", mlir::torch::Torch::createTorchShapeRefinementPipeline); } @@ -66,7 +66,8 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( // Perform the bulk of lowering to the backend contract. // See the pass documentation for more information. pm.addPass(createLowerToBackendContractPass( - options.maxIterations, options.decompose, options.backendLegalOps)); + options.maxIterations, options.decompose, options.backendLegalOps, + options.extraLibrary)); } // A simplification pipeline to establish the invariants of the backend @@ -106,9 +107,10 @@ void mlir::torch::Torch::createTorchSimplificationPipeline( // Clean up again to avoid needing to to back around the fixed-point // iteration. pm.addNestedPass(createCanonicalizerPass()); - pm.addNestedPass(createRecomposeComplexOps()); + pm.addNestedPass(createRecomposeComplexOpsPass()); // Reduce variants of ops to a smaller set of primitives. - pm.addNestedPass(createReduceOpVariantsPass()); + pm.addNestedPass( + createReduceOpVariantsPass(options.extraLibrary)); pm.addNestedPass(createCanonicalizerPass()); // Remove dead global slots. pm.addPass(createSymbolDCEPass()); @@ -121,8 +123,8 @@ void mlir::torch::Torch::createTorchSimplificationPipeline( // This should be run before RefineTypes (which primarily does dtype // inference), because Torch type promotion rules actually depend on the shape // of the operand. - createTorchShapeRefinementPipeline(pm); - createTorchDtypeRefinementPipeline(pm); + createTorchShapeRefinementPipeline(pm, options); + createTorchDtypeRefinementPipeline(pm, options); // Refine types in the program, which mainly means inferring dtypes of ops. pm.addNestedPass(Torch::createRefineTypesPass()); // Propagate to ABI return types the shape/dtype information discovered by @@ -141,13 +143,15 @@ void mlir::torch::Torch::createTorchSimplificationPipeline( static void createRefinementPipeline( mlir::OpPassManager &pm, - llvm::function_ref>()> + llvm::function_ref< + std::unique_ptr>(llvm::StringRef)> reifyCalculationsPass, llvm::function_ref< std::unique_ptr>()> - simplifyCalculationsPass) { + simplifyCalculationsPass, + const mlir::torch::Torch::TorchLoweringPipelineOptions &options) { // Reify the library functions for each op that is present in the library. - pm.addPass(reifyCalculationsPass()); + pm.addPass(reifyCalculationsPass(options.extraLibrary)); // Inline the library functions to enable analysis and transformation. // TODO: Only inline library functions (this will currently inline @@ -168,12 +172,14 @@ static void createRefinementPipeline( mlir::torch::Torch::createDropAbstractInterpCalculationsPass()); } -void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) { +void mlir::torch::Torch::createTorchShapeRefinementPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options) { createRefinementPipeline(pm, Torch::createReifyShapeCalculationsPass, - Torch::createSimplifyShapeCalculationsPass); + Torch::createSimplifyShapeCalculationsPass, options); } -void mlir::torch::Torch::createTorchDtypeRefinementPipeline(OpPassManager &pm) { +void mlir::torch::Torch::createTorchDtypeRefinementPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options) { createRefinementPipeline(pm, Torch::createReifyDtypeCalculationsPass, - Torch::createSimplifyDtypeCalculationsPass); + Torch::createSimplifyDtypeCalculationsPass, options); } diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index 7a5269946a48..dbddcc312927 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; @@ -71,19 +72,68 @@ class RecomposeSliceCopy_ : public OpRewritePattern { return success(); } }; + +class RecomposeSelectFill_ : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFill_TensorOp op, + PatternRewriter &rewriter) const override { + if (!op.getSelf().getDefiningOp() || + !isa(op.getSelf().getDefiningOp())) + return failure(); + auto selectOp = cast(op.getSelf().getDefiningOp()); + + // Get indices + int64_t dim; + if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim))) + return failure(); + + Value noneVal = rewriter.create(op.getLoc()); + Value falseVal = rewriter.create(op.getLoc(), false); + + // Create IndexPut_Op + // Convert indexNum to indexTensor for the selectOp + BaseTensorType selectOutTy = + selectOp.getType().template cast(); + SmallVector empty; + auto dtype = getTypeForTorchType(selectOp.getContext(), + selectOp.getIndex().getType()); + Type emptyTensorType = + selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); + Value indexTensor = rewriter.create( + selectOp.getLoc(), emptyTensorType, selectOp.getIndex()); + + // Create indicesVector for IndexPut_Op by TorchNone and indexTensor + BaseTensorType tensorType = op->getResultTypes()[0].cast(); + SmallVector indicesVector(dim - 1, noneVal); + indicesVector.push_back(indexTensor); + + Value indices = rewriter.create( + op.getLoc(), + Torch::ListType::get(op->getContext(), + Torch::OptionalType::get(tensorType)), + indicesVector); + + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), selectOp.getSelf(), indices, op.getValue(), + /*accumulate=*/falseVal, /*unsafe=*/falseVal); + + return success(); + } +}; } // namespace namespace { -class RecomposeComplexOps - : public DecomposeComplexOpsBase { +class RecomposeComplexOpsPass + : public RecomposeComplexOpsBase { public: - RecomposeComplexOps() = default; void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); // pattern.add calls go here patterns.add(context); + patterns.add(context); GreedyRewriteConfig config; config.useTopDownTraversal = true; @@ -98,6 +148,6 @@ class RecomposeComplexOps } // namespace std::unique_ptr> -mlir::torch::Torch::createRecomposeComplexOps() { - return std::make_unique(); +mlir::torch::Torch::createRecomposeComplexOpsPass() { + return std::make_unique(); } diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 78bc0703e46c..3b30b8a82722 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -12,6 +12,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "ReifyAbstractInterpCalculationsUtils.h" #include "llvm/ADT/StringExtras.h" using namespace mlir; @@ -52,17 +53,39 @@ static Type getContainerOrTensorTypeWithValueSemantics(Type type) { } } +static bool +operatorOpHasValueSemantics(OperatorOp opOp, + std::optional extraLibrary) { + if (!extraLibrary.has_value()) + return false; + auto opName = opOp->getAttr("name").cast().getValue(); + std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix( + LibraryFunctionKind::HasValueSemantics) + + Twine(opName)) + .str(); + auto libFunc = extraLibrary->lookup(libFuncName); + return bool(libFunc); +} + namespace { // Convert value semantic ops operating on mutable arrays to instead operate on // immutable tensors. class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { public: - ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context, + const std::optional& extraLibrary) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) { + this->extraLibrary = extraLibrary; + } LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (!op->hasTrait()) + if (isa(op)) { + if (!operatorOpHasValueSemantics(cast(op), extraLibrary)) { + return rewriter.notifyMatchFailure(op, "does not have value semantics"); + } + } else if (!op->hasTrait()) { return rewriter.notifyMatchFailure(op, "does not have value semantics"); + } rewriter.startRootUpdate(op); // Convert all operands. @@ -160,6 +183,8 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { rewriter.finalizeRootUpdate(op); return success(); } +private: + std::optional extraLibrary; }; } // namespace @@ -241,11 +266,30 @@ reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op, } namespace { -class ReduceOpVariantsPass : public ReduceOpVariantsBase { +struct ReduceOpVariantsPass + : public ReduceOpVariantsBase { + ReduceOpVariantsPass() = default; + ReduceOpVariantsPass(StringRef extraLibrary) { + this->extraLibrary = extraLibrary.str(); + } void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns.add(context); + OwningOpRef extraLibraryModule = + ModuleOp::create(UnknownLoc::get(context)); + std::optional extraLibraryModuleSymTable = std::nullopt; + if (!extraLibrary.empty()) { + if (failed(loadExtraLibrary(extraLibrary, extraLibraryModule))) { + emitError(getOperation()->getLoc(), + "Failed to load extra-library file at " + extraLibrary); + return signalPassFailure(); + } + + extraLibraryModuleSymTable = + SymbolTable(extraLibraryModule->getOperation()); + } + patterns.add( + context, extraLibraryModuleSymTable); patterns.add(context); patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp); patterns.add(context); @@ -253,8 +297,12 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { ConversionTarget target(*context); target.addIllegalOp(); target.addIllegalOp(); - target.markUnknownOpDynamicallyLegal([](Operation *op) { - if (op->hasTrait()) { + target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable]( + Operation *op) { + if (op->hasTrait() || + (isa(op) && + operatorOpHasValueSemantics(cast(op), + extraLibraryModuleSymTable))) { auto hasValueSemantics = [](Type t) { // TODO: Make this an allowlist based on a closed torch dialect // type system. @@ -281,6 +329,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase { } // namespace std::unique_ptr> -mlir::torch::Torch::createReduceOpVariantsPass() { - return std::make_unique(); +mlir::torch::Torch::createReduceOpVariantsPass(StringRef extraLibrary) { + return std::make_unique(extraLibrary); } diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 7e3697302c5b..e888d0892710 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -8,18 +8,25 @@ //===----------------------------------------------------------------------===// #include "ReifyAbstractInterpCalculationsUtils.h" +#include "mlir/Parser/Parser.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/ErrorOr.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -static std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) { +std::string +mlir::torch::Torch::getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind) { if (libFuncKind == LibraryFunctionKind::ShapeFunction) return "__torch_mlir_shape_fn."; else if (libFuncKind == LibraryFunctionKind::DtypeFunction) return "__torch_mlir_dtype_fn."; + else if (libFuncKind == LibraryFunctionKind::HasValueSemantics) + return "__torch_mlir_has_value_semantics_fn."; llvm_unreachable( "`getLibraryFunctionPrefix` called with an unsupported `CalculateOp`"); } @@ -73,6 +80,8 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( // looking them up in the library. if (name.startswith("valsem.")) name = name.drop_front(strlen("valsem.")); + if (isa(op)) + name = cast(op)->getAttr("name").cast().getValue(); std::string libFuncName = (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str(); auto libFunc = library.lookupSymbol(libFuncName); @@ -300,3 +309,39 @@ FailureOr Torch::adjustFunctionArg( // Pass the operand as-is. return operand; } + +LogicalResult +mlir::torch::Torch::loadExtraLibrary(const std::string &filename, + OwningOpRef &moduleToAppendTo) { + auto ctx = moduleToAppendTo->getContext(); + assert(ctx && "Module should be fully initialized."); + + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(filename); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << ec.message() << "\n"; + return failure(); + } + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + OwningOpRef module_ = + mlir::parseSourceFile(sourceMgr, ctx); + if (!module_) { + llvm::errs() << "Error can't load file " << filename << "\n"; + return failure(); + } + + assert((moduleToAppendTo->getBodyRegion().empty() || + moduleToAppendTo->getBodyRegion().hasOneBlock()) && + "Module should have at most one block."); + if (moduleToAppendTo->getBodyRegion().empty()) { + moduleToAppendTo = std::move(module_); + } else { + Block *block = moduleToAppendTo->getBody(0); + block->getOperations().splice(block->end(), + module_->getBody(0)->getOperations()); + } + + return success(); +} diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.h b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.h index bef6676147f3..fa336407829c 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.h +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.h @@ -22,7 +22,12 @@ namespace mlir { namespace torch { namespace Torch { -enum class LibraryFunctionKind { ShapeFunction, DtypeFunction, Decomposition }; +enum class LibraryFunctionKind { + ShapeFunction, + DtypeFunction, + Decomposition, + HasValueSemantics +}; // Searches the function library for an abstract interpretation function for // `op`. If one is found, wraps the op in a `CalculateOp`, with the op placed in @@ -60,6 +65,16 @@ FailureOr adjustFunctionArg( OpBuilder &b, Location loc, Value operand, Type desiredType, function_ref baseTransformation = [](OpBuilder &, Location, Value operand, Type) { return operand; }); + +std::string getLibraryFunctionPrefix(LibraryFunctionKind libFuncKind); + +// Parse MLIR module at `filename` into a ModuleOp that will then +// be appended to an existing, fully hydrated, ModuleOp; note the module +// should have been instantiated with an associated context like so: +// `OwningOpRef module = ModuleOp::create(UnknownLoc::get(&context));` +LogicalResult loadExtraLibrary(const std::string &filename, + OwningOpRef &moduleToAppendTo); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index 9eac538743b4..ac6d1ceac363 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -61,13 +61,23 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, } namespace { -class ReifyDtypeCalculationsPass +struct ReifyDtypeCalculationsPass : public ReifyDtypeCalculationsBase { + ReifyDtypeCalculationsPass() = default; + ReifyDtypeCalculationsPass(StringRef extraLibrary) { + this->extraLibrary = extraLibrary.str(); + } void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); OwningOpRef library = parseSourceString(getAbstractInterpLibrary(), context); + if (!extraLibrary.empty()) + if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) { + emitError(module->getLoc(), + "Failed to load extra-library file at " + extraLibrary); + return signalPassFailure(); + } // Walk all the operations, and if we have a dtype function, wrap the op // in a `torch.dtype.calculate` op. @@ -86,6 +96,6 @@ class ReifyDtypeCalculationsPass } // namespace std::unique_ptr> -Torch::createReifyDtypeCalculationsPass() { - return std::make_unique(); +Torch::createReifyDtypeCalculationsPass(StringRef extraLibrary) { + return std::make_unique(extraLibrary); } diff --git a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp index 8ace4e5e1bdf..f755b5c0a405 100644 --- a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp @@ -14,6 +14,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "llvm/Support/MemoryBuffer.h" using namespace mlir; using namespace mlir::torch; @@ -55,8 +56,12 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc, } namespace { -class ReifyShapeCalculationsPass +struct ReifyShapeCalculationsPass : public ReifyShapeCalculationsBase { + ReifyShapeCalculationsPass() = default; + ReifyShapeCalculationsPass(StringRef extraLibrary) { + this->extraLibrary = extraLibrary.str(); + } void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -66,6 +71,12 @@ class ReifyShapeCalculationsPass // O(#ops in the program) ideally. OwningOpRef library = parseSourceString(getAbstractInterpLibrary(), context); + if (!extraLibrary.empty()) + if (failed(mlir::torch::Torch::loadExtraLibrary(extraLibrary, library))) { + emitError(module->getLoc(), + "Failed to load extra-library file at " + extraLibrary); + return signalPassFailure(); + } // Walk all the operations, and if we have a shape function, wrap the op // in a `torch.shape.calculate` op. @@ -84,6 +95,6 @@ class ReifyShapeCalculationsPass } // namespace std::unique_ptr> -mlir::torch::Torch::createReifyShapeCalculationsPass() { - return std::make_unique(); +mlir::torch::Torch::createReifyShapeCalculationsPass(StringRef extraLibrary) { + return std::make_unique(extraLibrary); } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index d7fdf9481d5f..f84b1a2ea02e 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -203,7 +203,8 @@ bool Torch::isViewLikeOp(Operation *op) { AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp, TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, - AtenNarrowOp, AtenToDeviceOp>(op); + AtenNarrowOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp, + PrimsViewOfOp>(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, @@ -255,3 +256,69 @@ SmallVector Torch::makeShapeTorchCompatible(ArrayRef shape) { } return updatedShape; } + +// Helper function to squeeze the input tensor at given dim. +// Return the squeezed tensor or failure. +FailureOr Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op, + Location loc, int64_t dim, Value input) { + BaseTensorType inputType = input.getType().cast(); + if (!inputType.hasSizes()) { + return rewriter.notifyMatchFailure(loc, "input tensor must have size"); + } + SmallVector inputShape{inputType.getSizes()}; + unsigned inputRank = inputShape.size(); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) { + return rewriter.notifyMatchFailure( + op, "dimension to be squeezed is an invalid dim"); + } + inputShape.erase(inputShape.begin() + dim); + Type squeezedType = + inputType.getWithSizesAndDtype(inputShape, inputType.getOptionalDtype()); + + Value cstDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(dim)); + // Adding a check to verify if the dimension to be squeezed has size 1 or not. + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value dimSize = rewriter.create(loc, input, cstDim); + Value cmp = rewriter.create(loc, dimSize, cstOne); + rewriter.create( + loc, cmp, + "squeeze operation possible for dim only when input_shape[dim] == 1."); + + Value result = + rewriter.create(loc, squeezedType, input, cstDim); + return result; +} + +// Helper function to unsqueeze the input tensor at given dim. +// Return the unsqueezed tensor or failure. +FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, + Operation *op, Value input, Value dim) { + BaseTensorType inputType = input.getType().cast(); + if (!inputType.hasSizes()) { + return rewriter.notifyMatchFailure(op, "input tensor must have size"); + } + + SmallVector unsqueezedShape; + ArrayRef inputShape = inputType.getSizes(); + // `input` has a reduced rank. Hence add 1. + int64_t unsqueezedRank = inputShape.size() + 1; + int64_t dimInt = 0; + if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { + dimInt = toPositiveDim(dimInt, unsqueezedRank); + if (!isValidDim(dimInt, unsqueezedRank)) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + unsqueezedShape.append(inputShape.begin(), inputShape.end()); + unsqueezedShape.insert(unsqueezedShape.begin() + dimInt, 1); + } else { + unsqueezedShape.resize(unsqueezedRank, kUnknownSize); + } + Type unsqueezedType = inputType.getWithSizesAndDtype( + unsqueezedShape, inputType.getOptionalDtype()); + Value unsqueezed = rewriter.create( + op->getLoc(), unsqueezedType, input, dim); + return unsqueezed; +} diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index 79b3d4229468..4d38f4965df2 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -73,5 +73,5 @@ Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder, value.cast()); } - return builder.create(loc, value, type); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index a5d5f9b7072c..1f7f4e8f8294 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,16 +1,22 @@ -set(LinkedLibs MLIRIR - MLIRPass - MLIRFuncTransforms - TorchMLIRTorchConversionDialect - TorchMLIRTorchDialect - TorchMLIRTorchPasses - TorchMLIRTorchToLinalg - TorchMLIRTorchToTMTensor - TorchMLIRTorchToArith - TorchMLIRTorchToSCF - TorchMLIRTorchConversionToMLProgram - MLIRMemRefTransforms) - +set(LinkedLibs + MLIRFuncTransforms + MLIRIR + MLIRLinalgTransforms + MLIRMemRefTransforms + MLIRPass + MLIRTosaTransforms + MLIRVectorTransforms + TorchMLIRTorchConversionDialect + TorchMLIRTorchConversionToMLProgram + TorchMLIRTorchDialect + TorchMLIRTorchPasses + TorchMLIRTorchToArith + TorchMLIRTorchToLinalg + TorchMLIRTorchToSCF + TorchMLIRTorchToTMTensor + TorchMLIRTorchToTosa + ) + if(TORCH_MLIR_ENABLE_STABLEHLO) list(APPEND LinkedLibs ChloPasses) endif() diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 14d8f360bfe1..51d917329128 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -72,7 +72,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addNestedPass(createConvertTorchToLinalgPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); - pm.addNestedPass(createConvertTorchConversionToMLProgramPass()); + pm.addPass(createConvertTorchConversionToMLProgramPass()); pm.addNestedPass(memref::createExpandOpsPass()); // Clean up any non-canonical code introduced above.. diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index 00117d89533f..380d09bf5222 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -77,7 +77,7 @@ class VerifyLinalgOnTensorsBackendContractPass // Tensor operations should go through linalg and the tensor dialect. target.addDynamicallyLegalDialect(opHasLegalTypes); target.addDynamicallyLegalDialect(opHasLegalTypes); - target.addDynamicallyLegalDialect(opHasLegalTypes); + target.addDynamicallyLegalDialect(opHasLegalTypes); target.addDynamicallyLegalDialect(opHasLegalTypes); target.addDynamicallyLegalDialect(opHasLegalTypes); target.addDynamicallyLegalDialect(opHasLegalTypes); diff --git a/python/test/compile_api/backend_legal_ops.py b/python/test/compile_api/backend_legal_ops.py index fe5e8abea7fa..98c034930243 100644 --- a/python/test/compile_api/backend_legal_ops.py +++ b/python/test/compile_api/backend_legal_ops.py @@ -18,6 +18,6 @@ def forward(self, x, y, z): example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)] print(torch_mlir.compile(AddmmModule(), example_args, - output_type="torch", backend_legal_ops=["torch.aten.addmm"])) + output_type="torch", backend_legal_ops=["aten.addmm"])) # CHECK-LABEL: @forward # CHECK: torch.aten.addmm diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 443512a6d54a..2d8b9e8822b5 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -3,11 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from typing import Optional, Sequence, Union, List, Dict, Tuple +from typing import Optional, Sequence, Union, List, Dict, Tuple, Callable, Iterable from enum import Enum import sys from io import StringIO +import tempfile from torch._functorch.compile_utils import strip_overloads import torch @@ -15,6 +16,7 @@ from torch_mlir.passmanager import PassManager from .compiler_utils import run_pipeline_with_repro_report from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder +from torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator import generate_library class OutputType(Enum): @@ -240,8 +242,8 @@ def _get_for_tracing( # ops in the backend contract, and move these lists somewhere deeper in the # compiler where each backend can "own" its set of legal ops. BACKEND_LEGAL_OPS = { - OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'], - OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ], + OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'], + OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints', ], OutputType.STABLEHLO: [], } @@ -252,6 +254,7 @@ def compile(model: torch.nn.Module, use_tracing: bool = False, ignore_traced_shapes=False, backend_legal_ops: Optional[Sequence[str]] = None, + extra_library: Iterable[Callable] = [], verbose: bool = False): """Convert a PyTorch model to MLIR. @@ -277,12 +280,28 @@ def compile(model: torch.nn.Module, backend_legal_ops: A list of ops that should be considered legal for the backend. An op that is considered legal will not be decomposed. This option is only valid with the `"torch"` output type. + extra_library: List of abstract interpretation functions to splice + into the abstract interpretation library. See + `docs/adding_abstract_interpretation_functions.md` for more info + on the format the functions should have. verbose: If true, print extra information about the conversion. Returns: An MLIR module that contains the converted model in the specified output type. """ + extra_library_file_name = "" + if len(extra_library) != 0: + extra_library_dict = {} + for library_func in extra_library: + extra_library_dict[library_func.__name__] = library_func + mlir_library = generate_library(extra_library_dict) + + extra_library_file_name = \ + tempfile.gettempdir() + "/custom_op_extra_library.mlir" + with open(extra_library_file_name, "w") as f: + f.write(mlir_library) + output_type = OutputType.get(output_type) example_args = ExampleArgs.get(example_args) if ignore_traced_shapes and not use_tracing: @@ -367,7 +386,8 @@ def compile(model: torch.nn.Module, if output_type == OutputType.RAW: return mb.module - option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" + option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \ + " extra-library=" + extra_library_file_name + "}" run_pipeline_with_repro_report( mb.module, f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})", diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index c275c0b2b1df..296c1caca99e 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -65,6 +65,8 @@ def run_pipeline_with_repro_report(module, {description} failed with the following diagnostics: {sys.stderr.getvalue()} + python exception: {e} + For Torch-MLIR developers, the error can be reproduced with: $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename} Add '{debug_options}' to get the IR dump for debugging purpose. diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp index 1cbb07262dd1..3e52c20c229a 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp @@ -56,5 +56,12 @@ std::vector compute_shape_bucketize( return {Shape(dtype, self.sizes().vec())}; } +std::vector compute_shape_copy( + const at::Tensor& self, + const at::Tensor& src, + bool non_blocking) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + } // namespace lazy } // namespace torch diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 2a175f5a031c..bb60facda8d7 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -18,6 +18,7 @@ # Shape Functions # ============================================================================== +# TODO: upstream this def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int): assert len(weight) == 2 assert len(indices) == 1 @@ -49,6 +50,9 @@ def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[i def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇atan〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇tanh〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -302,6 +306,14 @@ def aten〇var_mean〇correction〡shape(self: List[int], dim: Optional[List[int out = upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) return out, out +def aten〇var_mean〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> Tuple[List[int], List[int]]: + out = upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) + return out, out + +def aten〇var_mean〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> Tuple[int, int]: + _, self_dtype = self_rank_dtype + return self_dtype, self_dtype + def aten〇var_mean〡shape(self: List[int], unbiased: bool = True) -> Tuple[List[int], List[int]]: return [], [] @@ -314,17 +326,6 @@ def aten〇std〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) -def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): - dim = upstream_shape_functions.maybe_wrap_dim(dim, len(self)) - out: List[int] = [] - for i, self_dim in enumerate(self): - if i == dim: - if keepdim: - out.append(1) - else: - out.append(self_dim) - return out - @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. @@ -335,15 +336,19 @@ def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds. ]) def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: - if dim is None: - return [] - return _reduce_along_dim(self, dim, keepdim) + return upstream_shape_functions.argmax(self, dim, keepdim) + +# TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor, +# making it impossible to add support for it using the current design of the shape library. +def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]: + assert num_classes != -1, "getting num_classes from tensor contents is not supported" + return self + [num_classes] def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]: - return _reduce_along_dim(self, dim, keepdim) + return upstream_shape_functions.argmax(self, dim, keepdim) def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]: - reduced_shape = _reduce_along_dim(self, dim, keepdim) + reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) return reduced_shape, reduced_shape def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: @@ -358,12 +363,20 @@ def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]], def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) +def aten〇movedim〇int〡shape(self: List[int], source: int, destination: int) -> List[int]: + return upstream_shape_functions.movedim(self, [source], [destination]) + +def aten〇movedim〇int〡dtype(self_rank_dtype: Tuple[int, int], source: int, destination: int) -> int: + _, self_dtype = self_rank_dtype + return self_dtype + def aten〇transpose〇int〡shape(self: List[int], dim0: int, dim1: int) -> List[int]: return upstream_shape_functions.transpose(self, dim0, dim1) def aten〇t〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.transpose(self, 0, 1) +# TODO: upstream this def aten〇numpy_T〡shape(self: List[int]) -> List[int]: result_shape: List[int] = [] for i in self: @@ -390,22 +403,15 @@ def aten〇addmm〡shape(self: List[int], mat1: List[int], mat2: List[int], beta ErrorInvocation(TensorOfShape(2, 3, 4), TensorOfShape(2, 4)), # RHS is not rank 3. ]) def aten〇bmm〡shape(self: List[int], mat2: List[int]) -> List[int]: - assert len(self) == 3, "bmm only supports 3D tensors" - assert len(mat2) == 3, "bmm only supports 3D tensors" - assert self[0] == mat2[0], "mismatching batch dimension" - assert self[2] == mat2[1], "mismatching contracting dimension" - return [self[0], self[1], mat2[2]] + return upstream_shape_functions.bmm(self, mat2) def aten〇baddbmm〡shape(self: List[int], batch1: List[int], batch2: List[int], beta: float = 1, alpha: float = 1) -> List[int]: - assert len(batch1) == 3, "baddbmm only supports 3D tensors" - assert len(batch2) == 3, "baddbmm only supports 3D tensors" - assert batch1[0] == batch2[0], "mismatching batch dimension" - assert batch1[2] == batch2[1], "mismatching contracting dimension" - return [batch1[0], batch1[1], batch2[2]] + return upstream_shape_functions.bmm(batch1, batch2) def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]: return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse) +# TODO: upstream this def aten〇repeat〡shape(self: List[int], repeats: List[int]) -> List[int]: assert len(repeats) >= len(self) ndim = len(repeats) @@ -607,6 +613,9 @@ def aten〇randn_like〡shape(self: List[int], dtype: Optional[int] = None, layo def aten〇randint〇low〡shape(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size +def aten〇randint〡shape(high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + return size + def aten〇randn〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size @@ -712,6 +721,16 @@ def aten〇squeeze〡shape(self: List[int]) -> List[int]: def aten〇squeeze〇dim〡shape(self: List[int], dim: int) -> List[int]: return upstream_shape_functions.squeeze(self, dim) +def prims〇squeeze〡shape(a: List[int], dimensions: List[int]) -> List[int]: + return upstream_shape_functions.squeeze_dims(a, dimensions) + +def prims〇view_of〡shape(a: List[int]) -> List[int]: + return a + +def prims〇view_of〡dtype(a_rank_dtype: Tuple[int, int]) -> int: + _, a_dtype = a_rank_dtype + return a_dtype + def prim〇NumToTensor〇Scalar〡shape(a: float) -> List[int]: return [] @@ -759,12 +778,11 @@ def aten〇addcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[in ErrorInvocation(TensorOfShape(2, 3), 2, dim=100), # `dim` out of bounds. ]) def aten〇topk〡shape(self: List[int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[List[int], List[int]]: - assert k <= self[dim], f"k ({k}) is too big for dimension {dim} of size {self[dim]}" - # All lists which represent tensor shapes are expected to be the result - # of a fresh invocation of `AtenSizeOp`, which allocates a new, unaliased - # list. So in-place mutations are ok. - self[dim] = k - return self, self + return upstream_shape_functions.topk(self, k, dim) + +def aten〇topk〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[int, int]: + _, self_dtype = self_rank_dtype + return self_dtype, torch.int64 def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) @@ -781,27 +799,59 @@ def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optio def aten〇_convolution〇deprecated〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) +_convolution_deprecated_kwargs = { + "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], + "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} +@check_dtype_function( + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Same type + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.int32), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different type + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), # Different width + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bfloat16), TensorOfShape(1, 1, 1, 1, dtype=torch.int32), # Different type and width + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.complex64), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.complex128), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.bool), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float16), TensorOfShape(1, 1, 1, 1, dtype=torch.float32), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs), + ErrorInvocation(TensorOfShape(1, 1, 1, 1, dtype=torch.float32), TensorOfShape(1, 1, 1, 1, dtype=torch.float16), + TensorOfShape(1, dtype=torch.float32), **_convolution_deprecated_kwargs) +]) +def aten〇_convolution〇deprecated〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert input_dtype == weight_dtype + assert input_dtype not in [torch.bool, torch.float16, torch.complex64, torch.complex128] + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]: return self def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int], weight: List[int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]: return upstream_shape_functions.conv_backwards(grad_output, input, weight, bias_sizes) -def aten〇convolution_backward_overrideable〡shape(grad_output: List[int], input: List[int], weight: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]: - return upstream_shape_functions.conv_backwards(grad_output, input, weight, None) - def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]: - # Torch's symbolic shape analysis is a bit looser about optional - # arguments than we are, so their batch_norm helper function works - # even though the `weight` is not `Optional`. - # Upstream is working to make this more consistent. - # For now, since this function is so trivial, just write it ourselves. - #return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled) - return input + return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled) def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return upstream_shape_functions.slice(self, dim, start, end, step) +def aten〇sort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> Tuple[List[int], List[int]]: + return self, self + +def aten〇sort〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1, descending: bool = False) -> Tuple[int, int]: + _, input_dtype = self_rank_dtype + return input_dtype, torch.long + def aten〇narrow〡shape(self: List[int], dim: int, start: int, length: int) -> List[int]: return upstream_shape_functions.slice(self, dim, start, start + length, 1) @@ -842,41 +892,25 @@ def aten〇_embedding_bag〡shape(weight: List[int], indices: List[int], offsets ErrorInvocation(TensorOfShape(2, 3), LongTensorOfShape(7), None, 1, -100), # Mismatched batch dimension. ]) def aten〇nll_loss_forward〡shape(self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int) -> Tuple[List[int], List[int]]: - # This is taken shamelessly from the meta function in LossNLL.cpp - self_dim = len(self) - target_dim = len(target) - assert 0 < self_dim <= 2 - assert target_dim <= 1 - no_batch_dim = self_dim == 1 and target_dim == 0 - assert no_batch_dim or (self[0] == target[0]) - n_classes = self[-1] - scalar_shape: List[int] = [] - assert weight is None or (len(weight) == 1 and weight[0] == n_classes) - if reduction == 0 and self_dim == 2: - return [self[0]], scalar_shape - else: - return scalar_shape, scalar_shape + return upstream_shape_functions.nll_loss_forward(self, target, weight, reduction) def aten〇nll_loss_backward〡shape(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +# TODO: upstream this def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]: if reduction == 0: return upstream_shape_functions.unary(self) return [] +def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]: + return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing) + @check_shape_function([ Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case. ]) def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[int], weight: Optional[List[int]], bias: Optional[List[int]], eps: float) -> Tuple[List[int], List[int], List[int]]: - reduction_shape: List[int] = [] - num_unreduced_dimensions = len(input) - len(normalized_shape) - assert num_unreduced_dimensions >= 0 - for i in range(num_unreduced_dimensions): - reduction_shape.append(input[i]) - for i in range(num_unreduced_dimensions, len(input)): - reduction_shape.append(1) - return input, reduction_shape, reduction_shape + return upstream_shape_functions.native_layer_norm(input, normalized_shape) @check_shape_function([ Invocation(TensorOfShape(2, 3), None, None, None, None, True, 1e-4, 1e-6), # Training basic case. @@ -886,14 +920,7 @@ def aten〇native_layer_norm〡shape(input: List[int], normalized_shape: List[in Invocation(TensorOfShape(2), None, None, None, None, True, 1e-4, 1e-6) # 1D input. ]) def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float) -> Tuple[List[int], List[int], List[int]]: - if training: - if len(input) >= 2: - return input, [input[1]], [input[1]] - else: - return input, [], [] - running_mean_list: List[int] = [0] if running_mean is None else running_mean - running_var_list: List[int] = [0] if running_var is None else running_var - return input, running_mean_list, running_var_list + return upstream_shape_functions.native_batch_norm(input, weight, bias, running_mean, running_var, training) # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example. @@ -919,6 +946,7 @@ def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]: return pad_shape_fn(self, pad) +# TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: assert len(indices) <= len(self), "More indices than dimensions to index" broadcasted_shape: List[int] = [] diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py index 717923cb17a2..f25508897fa5 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/library_generator.py @@ -5,7 +5,7 @@ import inspect import re -from typing import List, Optional, Union +from typing import List, Optional, Union, Any, Dict import torch @@ -187,25 +187,30 @@ def _verify_signature_matches_registry(f, registry: Registry): atoms = function_name.split("〇") if len(atoms) == 2: atoms += [""] - operator = registry.get_by_triple(tuple(atoms)) + try: + operator = registry.get_by_triple(tuple(atoms)) + except KeyError as e: + raise ValueError(f"Unable to find op {'.'.join(atoms)} in registry") if function_kind == "shape": expected_signature = operator.get_shape_function_signature() elif function_kind == "dtype": expected_signature = operator.get_dtype_function_signature() elif function_kind == "decomposition": expected_signature = operator.get_decomposition_function_signature() + elif function_kind == "has_value_semantics": + expected_signature = operator.get_has_value_semantics_function_signature() else: raise ValueError(f"Invalid Op signature function kind: '{function_kind}'") if signature != expected_signature: raise ValueError(f"Signature mismatch for {f.__name__!r}: expected {expected_signature!r}, got {signature!r}") -def generate_library(globals_) -> str: - """Convert all op functions in `globals()` into MLIR.""" +def generate_library(functions: Dict[str, Any]) -> str: + """Convert all op functions in `functions` into MLIR.""" mb = ModuleBuilder() # We use the registry to ensure that the shape functions are consistent # with the ops. registry = Registry.load() - for k, v in globals_.items(): + for k, v in functions.items(): if "〇" not in k: continue if not hasattr(v, "_not_present_in_registry"): diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py index 550b47802e76..0396df1a0081 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/registry.py @@ -267,6 +267,23 @@ def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: return self._get_function_signature( "decomposition", parameter_decl_builder, ret_decl_builder) + def get_has_value_semantics_function_signature(self): + """Gets the Python function signature for this op's has_value_semantics function. + + While this is technically debug-only output, it is useful to copy-paste + it from the debug dump into the library definitions, as many + ops have extra default arguments and stuff that are tedious to write out + right. + """ + def parameter_decl_builder(arg: "SIG_ATTR_TYPE") -> str: + return "" + + def ret_decl_builder(arg: "SIG_ATTR_TYPE") -> str: + return "None" + + return self._get_function_signature( + "has_value_semantics", parameter_decl_builder, ret_decl_builder) + def __repr__(self): f = io.StringIO() emitter = TextEmitter(f) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 56bb1ac25364..8d36f56459ea 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -255,6 +255,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::exp : (Tensor) -> (Tensor)", "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", + "aten::atan : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", "aten::floor : (Tensor) -> (Tensor)", @@ -322,6 +323,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::gelu : (Tensor, str) -> (Tensor)") emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)") + emit("aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)") emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)") emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)") @@ -334,6 +336,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)") emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)") emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)") emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)") emit("aten::randn : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randn.generator : (int[], Generator?, int?, int?, Device?, bool?) -> (Tensor)") @@ -359,12 +362,10 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") - emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)") emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)") emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"), emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)") - emit("aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)") emit("aten::flip : (Tensor, int[]) -> (Tensor)") emit( "aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)" @@ -406,6 +407,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)") + emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") emit("aten::bmm : (Tensor, Tensor) -> (Tensor)") emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)") emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)") @@ -422,6 +424,9 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::var.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor)") emit("aten::var_mean.correction : (Tensor, int[]?, Scalar?, bool) -> (Tensor, Tensor)") emit("aten::var_mean : (Tensor, bool) -> (Tensor, Tensor)") + emit("aten::var_mean.dim : (Tensor, int[]?, bool, bool) -> (Tensor, Tensor)") + emit("aten::nll_loss2d_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") + emit("aten::nll_loss2d_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") @@ -429,6 +434,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)") + emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)") # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") @@ -457,13 +463,14 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)") emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") + emit("aten::one_hot : (Tensor, int) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::clone : (Tensor, int?) -> (Tensor)") emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") emit("aten::contiguous : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)") emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)") - emit("aten::detach : (Tensor) -> (Tensor)") + emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True) emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)") emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)") emit("aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)") @@ -495,7 +502,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::amax : (Tensor, int[], bool) -> (Tensor)") emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True) - emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True) + emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer = True) emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)") emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)") emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)") @@ -544,6 +551,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::view_copy : (Tensor, int[]) -> (Tensor)") emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)") emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)") + emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") + emit("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)") emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") @@ -572,6 +581,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::ne.int_list : (int[], int[]) -> (bool)") emit("aten::any.bool : (bool[]) -> (bool)") emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True) + emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)") # Str ops. emit("aten::add.str : (str, str) -> (str)") @@ -661,7 +671,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("prim::layout : (Tensor) -> (int)") emit("prim::TupleIndex : (Any, int) -> (Any)", has_canonicalizer=True) - emit("prim::device : (Tensor) -> (Device)") + emit("prim::device : (Tensor) -> (Device)", has_canonicalizer=True) emit("prim::dtype : (Tensor) -> (int)", has_folder=True) emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True) emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)") @@ -685,6 +695,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("prims::convert_element_type : (Tensor, int) -> (Tensor)") emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)") + emit("prims::squeeze : (Tensor, int[]) -> (Tensor)") + emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True) # ========================================================================== # `quantized::` namespace. diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp index b80392a5afed..e0420022d58a 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp @@ -324,7 +324,6 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, } // Import DenseElementsAttr data. - // TODO: Support bool tensors. // TODO: More import formats in C-API. auto numElements = tensor.numel(); auto tensor_cpu = tensor.cpu().contiguous(); @@ -346,10 +345,15 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor, return mlirDenseElementsAttrDoubleGet( shapedType, numElements, static_cast(tensorData)); break; - case ScalarType::Bool: + case ScalarType::Bool: { + // TODO: The signature of `mlirDenseElementsAttrBoolGet` should be changed + // upstream to take in a `const bool *` rather than a `const int *` to avoid + // the unnecessary copying into an array four times as large. + const int8_t *elements = static_cast(tensorData); + std::vector tensorDataVector(elements, elements + numElements); return mlirDenseElementsAttrBoolGet(shapedType, numElements, - static_cast(tensorData)); - break; + tensorDataVector.data()); + } break; case ScalarType::QInt8: return mlirDenseElementsAttrInt8Get( shapedType, numElements, static_cast(tensorData)); diff --git a/python/torch_mlir/dynamo.py b/python/torch_mlir/dynamo.py index 36b98c305508..5f969def38e5 100644 --- a/python/torch_mlir/dynamo.py +++ b/python/torch_mlir/dynamo.py @@ -62,7 +62,8 @@ def _get_decomposition_table(): aten.native_group_norm_backward, aten.sigmoid_backward, aten._native_batch_norm_legit, - aten._native_batch_norm_legit_no_training + aten._native_batch_norm_legit_no_training, + aten.squeeze, ]) diff --git a/python/torch_mlir_e2e_test/configs/torchdynamo.py b/python/torch_mlir_e2e_test/configs/torchdynamo.py index 2b16b1b92d52..f22228fc5b2e 100644 --- a/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -2,7 +2,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. - from typing import List import numpy @@ -24,6 +23,7 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool: return False return True + @make_simple_dynamo_backend def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): @@ -45,11 +45,8 @@ def _refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, # Torch-MLIR does not support returning an empty tuple. The reason is # that both returning an empty tuple and returning `None` results in MLIR # functions that have as a return type `()`. In other words, there is no - # way of differentiating between the two. Moreover, since Torch-MLIR treats - # inputs as having value semantics, graphs that return nothing are no-ops to - # Torch-MLIR. - if _returns_empty_tuple(fx_graph): - return fx_graph + # way of differentiating between the two. + assert not _returns_empty_tuple(fx_graph), "encountered graph that does not return anything" mlir_module = torch_mlir.compile( fx_graph, example_inputs, output_type="linalg-on-tensors") @@ -92,7 +89,8 @@ def item_symbol_that_clones_inputs(*inputs): result: Trace = [] for item in trace: f = lambda method, *inputs: method(*inputs) - dynamo_f = dynamo.optimize(_refbackend_torchdynamo_backend)(f) + torch._dynamo.reset() + dynamo_f = dynamo.optimize(_refbackend_torchdynamo_backend, nopython=True)(f) output = dynamo_f(item_symbol_that_clones_inputs, *item.inputs) result.append( TraceItem(symbol=item.symbol, diff --git a/python/torch_mlir_e2e_test/test_suite/backprop.py b/python/torch_mlir_e2e_test/test_suite/backprop.py index 46d61d0e605c..7caa8a4c1cb8 100644 --- a/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -113,6 +113,39 @@ def ConvolutionBackwardModule2D_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 5, 5), tu.rand(2, 2, 6, 6), tu.rand(2, 2, 2, 2)) +class ConvolutionBackwardModule2DStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 4, 64, 64], torch.float32, True), + ([1, 320, 64, 64], torch.float32, True), + ([4, 320, 3, 3], torch.float32, True), + ]) + def forward(self, grad_out, input_vec, weight): + return torch.ops.aten.convolution_backward( + grad_out, + input_vec, + weight, + bias_sizes=[4], + stride=[1, 1], + padding=[1, 1], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + output_mask=[True, True, True]) + + +@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStatic()) +def ConvolutionBackwardModule2DStatic_basic(module, tu: TestUtils): + with torch.backends.mkldnn.flags(enabled=False): + module.forward(tu.rand(1, 4, 64, 64), tu.rand(1, 320, 64, 64), + tu.rand(4, 320, 3, 3)) + class ConvolutionBackwardModule2DPadded(torch.nn.Module): @@ -148,6 +181,40 @@ def ConvolutionBackwardModule2DPadded_basic(module, tu: TestUtils): tu.rand(2, 2, 3, 3)) +class ConvolutionBackwardModule2DStrided(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 2, 4, 4], torch.float32, True), + ([1, 2, 8, 8], torch.float32, True), + ([2, 2, 3, 3], torch.float32, True), + ]) + def forward(self, grad_out, input_vec, weight): + return torch.ops.aten.convolution_backward( + grad_out, + input_vec, + weight, + bias_sizes=[4], + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=1, + output_mask=[True, True, True]) + + +@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStrided()) +def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils): + with torch.backends.mkldnn.flags(enabled=False): + module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), + tu.rand(2, 2, 3, 3)) + + # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 96b67e09c870..e312ebcb6787 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -660,7 +660,7 @@ def __init__(self): ([-1, -1, -1], torch.float32, True), ]) def forward(self, x, y, z): - return torch.stack([x, y, z], 1) + return torch.stack([x, y, z], dim=1) @register_test_case(module_factory=lambda: TensorsStackModule()) @@ -708,7 +708,7 @@ def __init__(self): ([-1, -1, -1], torch.int64, True), ]) def forward(self, x, y, z): - return torch.cat([x, y, z], dim=-2) + return torch.stack([x, y, z], dim=-2) @register_test_case(module_factory=lambda: TensorsStackPromoteDTypeModule()) @@ -744,6 +744,29 @@ def GatherModule_basic(module, tu: TestUtils): # ============================================================================== +class GatherNegativeDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, tensor, indices): + return torch.gather(tensor, -1, indices) + + +@register_test_case(module_factory=lambda: GatherNegativeDimModule()) +def GatherNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4), torch.tensor([[[1, 2, 3], [1, 2, 3]]])) + + +# ============================================================================== + + class GatherRandomIndexModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1388,6 +1411,27 @@ def PrimMinIntModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== + +class PrimMaxIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.prim.max(a.size(0), a.size(1)) + + +@register_test_case(module_factory=lambda: PrimMaxIntModule()) +def PrimMaxIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5)) + + # ============================================================================== class NumToTensorIntModule(torch.nn.Module): @@ -3364,6 +3408,100 @@ def SortIntListReverse_basic(module, tu: TestUtils): # ============================================================================== + +class SortTensor(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True) + ]) + def forward(self, input): + return torch.sort(input) + + +@register_test_case(module_factory=lambda: SortTensor()) +def SortTensor_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +class SortTensorInteger(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True) + ]) + def forward(self, input): + return torch.sort(input) + + +@register_test_case(module_factory=lambda: SortTensorInteger()) +def SortTensorInteger_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 3)) + + +class SortTensorDescending(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True) + ]) + def forward(self, input): + return torch.sort(input, descending=True) + + +@register_test_case(module_factory=lambda: SortTensorDescending()) +def SortTensorDescending_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +class SortTensorSpecificDimension(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True) + ]) + def forward(self, input): + return torch.sort(input, dim=1) + + +@register_test_case(module_factory=lambda: SortTensorSpecificDimension()) +def SortTensorSpecificDimension_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +class SortTensorNegativeDimension(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True) + ]) + def forward(self, input): + return torch.sort(input, dim=-1) + + +@register_test_case(module_factory=lambda: SortTensorNegativeDimension()) +def SortTensorNegativeDimension_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + class BucketizeTensorModule(torch.nn.Module): def __init__(self): super().__init__() @@ -3470,3 +3608,153 @@ def forward(self, x): @register_test_case(module_factory=lambda: AtenFloatScalarModule()) def AtenFloatScalarModule_basic(module, tu: TestUtils): module.forward(tu.randint(high=5)) + + +# ============================================================================== + + +class MoveDimIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.movedim(x, source=1, destination=2) #0, 2, 1 + + +@register_test_case(module_factory=lambda: MoveDimIntModule()) +def MoveDimIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 2, 1)) + + +# ============================================================================== + + +class MoveDimIntNegativeIndexModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.movedim(x, source=-1, destination=1) + + +@register_test_case(module_factory=lambda: MoveDimIntNegativeIndexModule()) +def MoveDimIntNegativeIndexModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 2)) + + +# ============================================================================== + + +class PrimsViewOfModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.prims.view_of(x) + + +@register_test_case(module_factory=lambda: PrimsViewOfModule()) +def PrimsViewOfModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 2)) + + +class PrimsViewOfZeroRankModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([], torch.float32, True)]) + def forward(self, x): + return torch.ops.prims.view_of(x) + + +@register_test_case(module_factory=lambda: PrimsViewOfZeroRankModule()) +def PrimsViewOfZeroRankModule_basic(module, tu: TestUtils): + module.forward(tu.rand()) + + +# ============================================================================== + + +class OneHotModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.long, True)]) + def forward(self, x): + return torch.nn.functional.one_hot(x, num_classes=5) + + +@register_test_case(module_factory=lambda: OneHotModule()) +def OneHotModule_basic(module, tu: TestUtils): + module.forward(tu.randint(10, high=5)) + + +# ============================================================================== + + +class ConstantBoolParameterModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.bool_tensor = torch.tensor( + [True, False, True, False], dtype=torch.bool) + + @export + @annotate_args([ + None, + ]) + def forward(self): + return self.bool_tensor + + +@register_test_case(module_factory=lambda: ConstantBoolParameterModule()) +def ConstantBoolParameterModule_basic(module, tu: TestUtils): + module.forward() + + +# ============================================================================== + + +class AtenTopKModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.topk(x, k=50, dim=-1, largest=True, sorted=True) + + +@register_test_case(module_factory=lambda: AtenTopKModule()) +def AtenTopKModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 100)) + + +class AtenTopKSmallestModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.topk(x, k=20, dim=1, largest=False, sorted=True) + + +@register_test_case(module_factory=lambda: AtenTopKSmallestModule()) +def AtenTopKSmallestModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 40, 50)) diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 9755b87359d0..a4aa1e99bd10 100644 --- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -685,6 +685,24 @@ def NewZerosModuleFalsePinMemory_basic(module, tu: TestUtils): module.forward(tu.randint(2, 3, high=10)) +class NewZerosStaticModuleLayoutStrided(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 4], torch.int64, True), + ]) + def forward(self, a): + return a.new_zeros(a.shape) + + +@register_test_case(module_factory=lambda: NewZerosStaticModuleLayoutStrided()) +def NewZerosStaticModuleLayoutStrided_basic(module, tu: TestUtils): + module.forward(tu.randint(1, 4, high=10)) + # ============================================================================== diff --git a/python/torch_mlir_e2e_test/test_suite/conv.py b/python/torch_mlir_e2e_test/test_suite/conv.py index d36c8b75a539..006301b9fc79 100644 --- a/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/python/torch_mlir_e2e_test/test_suite/conv.py @@ -583,6 +583,31 @@ def forward(self, inputVec, weight): def ConvolutionModule2DTransposeStridedStatic_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) +class ConvolutionModule2DTransposeNonUnitOutputPadding(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, inputVec, weight): + return torch.ops.aten.convolution(inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[1, 1], + groups=1) + +@register_test_case(module_factory=lambda: ConvolutionModule2DTransposeNonUnitOutputPadding()) +def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3)) + class Conv_Transpose2dModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index ba23c52221a4..4c732317ac02 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -933,6 +933,51 @@ def ElementwiseMishModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAtanTensorFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.atan(a) + + +@register_test_case(module_factory=lambda: ElementwiseAtanTensorFloatModule()) +def ElementwiseAtanTensorFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4)) + + +# ============================================================================== + + +class ElementwiseAtanTensorIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ]) + def forward(self, a): + return torch.atan(a) + + +@register_test_case(module_factory=lambda: ElementwiseAtanTensorIntModule()) +def ElementwiseAtanTensorIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, low=1, high=10).type(torch.int32)) + + +# ============================================================================== + + class ElementwiseAtan2TensorFloatModule(torch.nn.Module): def __init__(self): @@ -1246,6 +1291,29 @@ def ElementwisePowTensorModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwisePowTensorStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.float32, True), + ([1, 1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.pow(a, b) + + +@register_test_case(module_factory=lambda: ElementwisePowTensorStaticModule()) +def ElementwisePowTensorStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4), tu.rand(1, 1)) + + +# ============================================================================== + + class ElementwisePowTensorBroadcastModule(torch.nn.Module): def __init__(self): @@ -1269,6 +1337,29 @@ def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwisePowTensorBroadcastStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 1], torch.float32, True), + ([3, 4], torch.float32, True), + ]) + def forward(self, a, b): + return torch.pow(a, b) + + +@register_test_case(module_factory=lambda: ElementwisePowTensorBroadcastStaticModule()) +def ElementwisePowTensorBroadcastStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1), tu.rand(3, 4)) + + +# ============================================================================== + + class ElementwiseToDtypeF32ToI64Module(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/index_select.py b/python/torch_mlir_e2e_test/test_suite/index_select.py index 6a426b45459e..0fdda62a13a0 100644 --- a/python/torch_mlir_e2e_test/test_suite/index_select.py +++ b/python/torch_mlir_e2e_test/test_suite/index_select.py @@ -31,6 +31,25 @@ def IndexSelectSingleIdxModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6), torch.tensor([2])) +class IndexSelectNegativeDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([4, 5, 6], torch.float32, True), + ([1], torch.int64, True), + ]) + + def forward(self, input, indices): + return torch.index_select(input, -1, indices) + +@register_test_case(module_factory=lambda: IndexSelectNegativeDimModule()) +def IndexSelectNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 5, 6), torch.tensor([2])) + + class IndexSelectTwoIdxModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/pooling.py b/python/torch_mlir_e2e_test/test_suite/pooling.py index bce8850eafb3..69073c6ab6c2 100644 --- a/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -124,6 +124,25 @@ def MaxPool2dModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, low=-1)) +class MaxPool2dEmptyStrideStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.max_pool2d(x, kernel_size=2, stride=[]) + + +@register_test_case(module_factory=lambda: MaxPool2dEmptyStrideStaticModule()) +def MaxPool2dEmptyStrideStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, low=-1)) + + class MaxPool2dStaticModule(torch.nn.Module): def __init__(self): @@ -146,6 +165,29 @@ def forward(self, x): def MaxPool2dStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 64, 112, 112)) +class MaxPool2dStaticCeilModeTrueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp2d = torch.nn.MaxPool2d(kernel_size=[3, 3], + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + ceil_mode=True) + + @export + @annotate_args([ + None, + ([1, 64, 112, 112], torch.float32, True), + ]) + def forward(self, x): + return self.mp2d(x) + + +@register_test_case(module_factory=lambda: MaxPool2dStaticCeilModeTrueModule()) +def MaxPool2dStaticCeilModeTrueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 64, 112, 112)) + class MaxPool2dCeilModeTrueModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index 70f5cef84618..dd2112110f6f 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -697,7 +697,6 @@ def forward(self, a): def NormScalarOptDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) - # ============================================================================== class NormScalarOptDimKeepDimModule(torch.nn.Module): @@ -717,7 +716,6 @@ def forward(self, a): def NormScalarOptDimKeepDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) - # ============================================================================== class ReduceFrobeniusNormModule(torch.nn.Module): def __init__(self) -> None: @@ -727,7 +725,7 @@ def __init__(self) -> None: @annotate_args([ None, ([-1, -1, -1], torch.float32, True), - ]) + ]) def forward(self, a): return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=False) @@ -736,6 +734,7 @@ def ReduceFrobeniusNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) # ============================================================================== + class ReduceFrobeniusNormKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -744,7 +743,7 @@ def __init__(self) -> None: @annotate_args([ None, ([-1, -1, -1], torch.float32, True), - ]) + ]) def forward(self, a): return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=True) @@ -754,6 +753,42 @@ def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils): # ============================================================================== +class LinalgVectorNormModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.linalg_vector_norm(a, ord=3.0, dim=[0, 1], keepdim=False) + +@register_test_case(module_factory=lambda: LinalgVectorNormModule()) +def LinalgVectorNormModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + +class LinalgVectorNormKeepDimModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.linalg_vector_norm(a, ord=3.0, dim=[0, 1], keepdim=True) + +@register_test_case(module_factory=lambda: LinalgVectorNormKeepDimModule()) +def LinalgVectorNormKeepDimModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + class MseLossNoReductionModule(torch.nn.Module): def __init__(self): super().__init__() @@ -764,7 +799,6 @@ def __init__(self): ([-1 , -1], torch.float32, True), ([-1 , -1], torch.float32, True), ]) - def forward(self, x, y): return torch.ops.aten.mse_loss(x, y, reduction=0) @@ -772,6 +806,7 @@ def forward(self, x, y): def MseLossNoReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4), tu.rand(2, 4)) +# ============================================================================== class MseLossMeanReductionModule(torch.nn.Module): def __init__(self): @@ -783,7 +818,6 @@ def __init__(self): ([-1 , -1], torch.float32, True), ([-1 , -1], torch.float32, True), ]) - def forward(self, x, y): return torch.ops.aten.mse_loss(x, y, reduction=1) @@ -791,6 +825,7 @@ def forward(self, x, y): def MseLossMeanReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4), tu.rand(2, 4)) +# ============================================================================== class MseLossSumReductionWithDifferentElemTypeModule(torch.nn.Module): def __init__(self): @@ -802,10 +837,48 @@ def __init__(self): ([-1 , -1], torch.float32, True), ([-1 , -1], torch.float64, True), ]) - def forward(self, x, y): return torch.ops.aten.mse_loss(x, y, reduction=2) @register_test_case(module_factory=lambda: MseLossSumReductionWithDifferentElemTypeModule()) def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4), tu.rand(2, 4).to(torch.float64)) + +# ============================================================================== + +class CrossEntropyLossModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1 , -1], torch.float32, True), + ([-1, ], torch.int64, True), + ]) + + def forward(self, input, target): + return torch.ops.aten.cross_entropy_loss(input, target) + +@register_test_case(module_factory=lambda: CrossEntropyLossModule()) +def CrossEntropyLossModule_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 2), tu.randint(8, high=2)) + + +class CrossEntropyLossNoReductionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1 , -1], torch.float32, True), + ([-1, ], torch.int64, True), + ]) + + def forward(self, input, target): + return torch.ops.aten.cross_entropy_loss(input, target, reduction=0) + +@register_test_case(module_factory=lambda: CrossEntropyLossNoReductionModule()) +def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 2), tu.randint(8, high=2)) diff --git a/python/torch_mlir_e2e_test/test_suite/rng.py b/python/torch_mlir_e2e_test/test_suite/rng.py index 14fd9d2dba92..22076e0310f9 100644 --- a/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/python/torch_mlir_e2e_test/test_suite/rng.py @@ -329,6 +329,65 @@ def RandIntLowDtypeModule_basic(module, tu: TestUtils): # ============================================================================== +class RandIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + a = torch.ops.aten.randint(high=1000, size=[1024, 1024]) + mean = torch.mean(a.to(torch.float32)) + return mean + + +@register_test_case(module_factory=lambda: RandIntModule()) +def RandIntModule_basic(module, tu: TestUtils): + module.forward() + +# ============================================================================== + +class RandIntDtypeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], dtype=torch.float64) + mean = torch.mean(a.to(torch.float32)) + return mean + + +@register_test_case(module_factory=lambda: RandIntDtypeModule()) +def RandIntDtypeModule_basic(module, tu: TestUtils): + module.forward() + +# ============================================================================== + +class RandIntPinMemoryModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + a = torch.ops.aten.randint(high=1000, size=[128, 256, 512], pin_memory=False) + mean = torch.mean(a.to(torch.float32)) + return mean + + +@register_test_case(module_factory=lambda: RandIntPinMemoryModule()) +def RandIntPinMemoryModule_basic(module, tu: TestUtils): + module.forward() + +# ============================================================================== class RandnModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/scatter.py b/python/torch_mlir_e2e_test/test_suite/scatter.py index 11fad62faf43..784ea2ac80fb 100644 --- a/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -966,3 +966,54 @@ def ScatterReduceIntMeanModuleIncludeSelf(module, tu: TestUtils): tu.randint(5, 8, 6, dtype=torch.int32, high=10)) # ============================================================================== + +class IndexPutImpl2DIndexModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, input, index, value): + return torch.ops.aten._index_put_impl_(input, (index, ), + value, + accumulate=True, + unsafe=False) + + +@register_test_case( + module_factory=lambda: IndexPutImpl2DIndexModule()) +def IndexPutImpl2DIndexModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 7), tu.randint(2, 3, high=3), tu.rand(2, 3, 7)) + +# ============================================================================== + +class IndexPutImplIndexWithNoneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 4, 5], torch.float32, True), + ([6, 1], torch.int64, True), + ([7], torch.int64, True), + ([2, 3, 6, 7], torch.float32, True), + ]) + def forward(self, input, index1, index2, value): + return torch.ops.aten._index_put_impl_(input, (None, None, index1, index2), + value, + accumulate=True, + unsafe=False) + + +@register_test_case( + module_factory=lambda: IndexPutImplIndexWithNoneModule()) +def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4, 5), tu.randint(6, 1, high=4), tu.randint(7, high=5), tu.rand(2, 3, 6, 7)) diff --git a/python/torch_mlir_e2e_test/test_suite/slice_like.py b/python/torch_mlir_e2e_test/test_suite/slice_like.py index 1e8566826547..08cb00e191a3 100644 --- a/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -307,6 +307,23 @@ def forward(self, x, src): def SliceScatterZeroDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(6, 8), tu.rand(1, 8)) +class SliceScatterNegativeEndModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, x, src): + return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 3, end = -1, step = 1) + + +@register_test_case(module_factory=lambda: SliceScatterNegativeEndModule()) +def SliceScatterNegativeEndModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 8), tu.rand(2, 8)) class SliceScatterNegativeDimModule(torch.nn.Module): diff --git a/python/torch_mlir_e2e_test/test_suite/squeeze.py b/python/torch_mlir_e2e_test/test_suite/squeeze.py index 04bf6e97cbd7..8b7cf957ac78 100644 --- a/python/torch_mlir_e2e_test/test_suite/squeeze.py +++ b/python/torch_mlir_e2e_test/test_suite/squeeze.py @@ -184,3 +184,44 @@ def forward(self, a): module_factory=lambda: SqueezeDimUnitDimModule()) def SqueezeDimModule_unitDim(module, tu: TestUtils): module.forward(tu.rand(1)) + + +# ============================================================================== + + +class PrimsSqueezeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 2, 3, 1, 4], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.prims.squeeze(a, dimensions=[0, 4, 1]) + + +@register_test_case( + module_factory=lambda: PrimsSqueezeModule()) +def PrimsSqueezeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 2, 3, 1, 4)) + + +class PrimsSqueezeEmptyDimensionsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 2, 1, 4], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.prims.squeeze(a, dimensions=[]) + + +@register_test_case( + module_factory=lambda: PrimsSqueezeEmptyDimensionsModule()) +def PrimsSqueezeEmptyDimensionsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 1, 4)) diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py index 26ac9aabd390..c6398b48ead3 100644 --- a/python/torch_mlir_e2e_test/test_suite/stats.py +++ b/python/torch_mlir_e2e_test/test_suite/stats.py @@ -1000,3 +1000,44 @@ def forward(self, x): @register_test_case(module_factory=lambda: VarMeanBiasedModule()) def VarMeanBiasedModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarMeanDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var_mean(x, dim=[1]) + + +@register_test_case(module_factory=lambda: VarMeanDimModule()) +def VarMeanDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +class VarMeanDimBiasedModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var_mean(x, dim=[1], unbiased=False, keepdim=True) + + +@register_test_case(module_factory=lambda: VarMeanDimBiasedModule()) +def VarMeanDimBiasedModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) diff --git a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py index e658471edebb..6999989a6743 100644 --- a/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/tosa_backends/linalg_on_tensors.py @@ -29,8 +29,8 @@ "tosa-to-linalg-named", # TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them # to arith.constants here before proceeding further. - "tosa-to-tensor", "tosa-to-linalg", + "tosa-to-tensor", "tosa-to-arith", ]) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 65012500d6e1..472215b92fe7 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -c54ce93106ef7d893be87a9f7b0e0bd98724b539 +ccace360e001a6574f4e6657fee919b756765878 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 93c8cf99c5ff..dd0c9c9a4575 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.1.0.dev20230310 +torch==2.1.0.dev20230505 diff --git a/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir new file mode 100644 index 000000000000..8ef04d95166e --- /dev/null +++ b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir @@ -0,0 +1,15 @@ +// RUN: torch-mlir-opt %s -convert-torch-conversion-to-mlprogram -split-input-file | FileCheck %s + +module { + func.func private @f0() -> i64 + func.func private @f1() -> i64 + func.func private @f2() -> i64 + func.func private @f3() -> i64 + func.func private @f4() -> i64 + func.func private @f5() -> i64 + func.func private @f6() -> i64 + func.func private @f7() -> i64 +} + +// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor +// CHECK-NOT: @global_seed diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index c6b2d429d838..52936c53b9b1 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -265,8 +265,12 @@ func.func @torch.aten.sqrt.int(%arg0: !torch.int) -> !torch.float { // CHECK: %[[CST_TRUE:.*]] = arith.constant true // CHECK: %[[TRUE:.*]] = torch_c.from_i1 %[[CST_TRUE]] // CHECK: %[[INPUT:.*]] = torch.prim.ListConstruct %[[FALSE]], %[[TRUE]], %[[FALSE]] : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list -// CHECK: %[[CST_RESULT:.*]] = arith.constant true -// CHECK: %[[RESULT:.*]] = torch_c.from_i1 %[[CST_RESULT]] +// CHECK: %[[TMP1:.*]] = torch_c.to_i1 %[[FALSE]] +// CHECK: %[[TMP2:.*]] = torch_c.to_i1 %[[TRUE]] +// CHECK: %[[TMP3:.*]] = torch_c.to_i1 %[[FALSE]] +// CHECK: %[[CMP:.*]] = arith.ori %[[TMP1]], %[[TMP2]] : i1 +// CHECK: %[[CMP_RESULT:.*]] = arith.ori %[[CMP]], %[[TMP3]] : i1 +// CHECK: %[[RESULT:.*]] = torch_c.from_i1 %[[CMP_RESULT]] // CHECK: return %[[RESULT]] : !torch.bool func.func @torch.aten.any.bool() -> !torch.bool { %false = torch.constant.bool false diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir similarity index 100% rename from test/Conversion/TorchToMhlo/basic.mlir rename to test/Conversion/TorchToStablehlo/basic.mlir diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToStablehlo/elementwise.mlir similarity index 100% rename from test/Conversion/TorchToMhlo/elementwise.mlir rename to test/Conversion/TorchToStablehlo/elementwise.mlir diff --git a/test/Conversion/TorchToMhlo/gather.mlir b/test/Conversion/TorchToStablehlo/gather.mlir similarity index 100% rename from test/Conversion/TorchToMhlo/gather.mlir rename to test/Conversion/TorchToStablehlo/gather.mlir diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir similarity index 84% rename from test/Conversion/TorchToMhlo/linear.mlir rename to test/Conversion/TorchToStablehlo/linear.mlir index 628969956684..b9bac97ca6c9 100644 --- a/test/Conversion/TorchToMhlo/linear.mlir +++ b/test/Conversion/TorchToStablehlo/linear.mlir @@ -45,7 +45,7 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> +// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> // CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32> @@ -71,7 +71,7 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32> @@ -97,7 +97,7 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> +// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T9]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32> // CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32> @@ -123,7 +123,7 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> +// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> // CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32> @@ -146,7 +146,7 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> // CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> +// CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T0]], %[[T7]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32> // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32> @@ -169,7 +169,7 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> // CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor -// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T7]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor to tensor // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> @@ -240,7 +240,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> // CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor -// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,256],f32> // CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32> @@ -360,11 +360,12 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar // CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_5:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32> -// CHECK: %[[T_6:.*]] = stablehlo.convolution(%[[T_0]], %[[T_5]]) -// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32> -// CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32> -// CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32> +// CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32> +// CHECK: %[[T_6:.*]] = stablehlo.reverse %[[T_5]], dims = [0, 1] : tensor<3x3x4x2xf32> +// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x2xf32>) -> tensor<1x4x9x9xf32> +// CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32> +// CHECK: return %[[T_8]] : !torch.vtensor<[1,4,9,9],f32> func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> { %true = torch.constant.bool true %none = torch.constant.none @@ -392,11 +393,12 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7, // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x4x3x3xf32> -// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]]) -// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> -// CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> -// CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32> +// CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32> +// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x4x2xf32> +// CHECK: %[[T_8:.*]] = stablehlo.convolution(%[[T_0]], %[[T_7]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x2xf32>) -> tensor<1x4x15x15xf32> +// CHECK: %[[T_9:.*]] = torch_c.from_builtin_tensor %[[T_8]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> +// CHECK: return %[[T_9]] : !torch.vtensor<[1,4,15,15],f32> func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { %true = torch.constant.bool true %none = torch.constant.none @@ -426,13 +428,12 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32> -// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]]) -// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> -// CHECK: %[[T_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T_9:.*]] = stablehlo.pad %[[T_7]], %[[T_8]], low = [0, 0, 0, 0], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x4x15x15xf32>, tensor) -> tensor<1x4x16x16xf32> -// CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32> -// CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32> +// CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32> +// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x4x2xf32> +// CHECK: %[[T_8:.*]] = stablehlo.convolution(%[[T_0]], %[[T_7]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 3], [2, 3]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x2xf32>) -> tensor<1x4x16x16xf32> +// CHECK: %[[T_9:.*]] = torch_c.from_builtin_tensor %[[T_8:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32> +// CHECK: return %[[T_9]] : !torch.vtensor<[1,4,16,16],f32> func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { %true = torch.constant.bool true %none = torch.constant.none @@ -462,31 +463,32 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x2x3x3xf32> -// CHECK: %[[IDX_0:.*]] = arith.constant 0 : index -// CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32> -// CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64 -// CHECK: %[[IDX_1:.*]] = arith.constant 1 : index -// CHECK: %[[T_9:.*]] = tensor.dim %[[T_6]], %[[IDX_1]] : tensor<2x2x3x3xf32> -// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 -// CHECK: %[[IDX_2:.*]] = arith.constant 2 : index -// CHECK: %[[T_11:.*]] = tensor.dim %[[T_6]], %[[IDX_2]] : tensor<2x2x3x3xf32> -// CHECK: %[[T_12:.*]] = arith.index_cast %[[T_11]] : index to i64 -// CHECK: %[[IDX_3:.*]] = arith.constant 3 : index -// CHECK: %[[T_13:.*]] = tensor.dim %[[T_6]], %[[IDX_3]] : tensor<2x2x3x3xf32> -// CHECK: %[[T_14:.*]] = arith.index_cast %[[T_13]] : index to i64 -// CHECK: %[[T_24:.*]] = arith.constant 2 : i64 -// CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64 -// CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64 -// CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64> -// CHECK: %[[T_18:.*]] = stablehlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32> -// CHECK: %[[T_19:.*]] = stablehlo.transpose %[[T_18]], dims = [1, 0, 2, 3, 4] : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32> -// CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64> -// CHECK: %[[T_21:.*]] = stablehlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32> -// CHECK: %[[T_22:.*]] = stablehlo.convolution(%[[T_0]], %[[T_21]]) -// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32> -// CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> -// CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32> +// CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32> +// CHECK: %[[T_7:.*]] = stablehlo.reverse %6, dims = [0, 1] : tensor<3x3x2x2xf32> +// CHECK: %c0 = arith.constant 0 : index +// CHECK: %dim = tensor.dim %[[T_7]], %c0 : tensor<3x3x2x2xf32> +// CHECK: %[[T_8:.*]] = arith.index_cast %dim : index to i64 +// CHECK: %c1 = arith.constant 1 : index +// CHECK: %dim_0 = tensor.dim %[[T_7]], %c1 : tensor<3x3x2x2xf32> +// CHECK: %[[T_9:.*]] = arith.index_cast %dim_0 : index to i64 +// CHECK: %c2 = arith.constant 2 : index +// CHECK: %dim_1 = tensor.dim %[[T_7]], %c2 : tensor<3x3x2x2xf32> +// CHECK: %[[T_10:.*]] = arith.index_cast %dim_1 : index to i64 +// CHECK: %c3 = arith.constant 3 : index +// CHECK: %dim_2 = tensor.dim %[[T_7]], %c3 : tensor<3x3x2x2xf32> +// CHECK: %[[T_11:.*]] = arith.index_cast %dim_2 : index to i64 +// CHECK: %c2_i64 = arith.constant 2 : i64 +// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %c2_i64 : i64 +// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %c2_i64 : i64 +// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %c2_i64, %[[T_12]] : tensor<5xi64> +// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xi64>) -> tensor<3x3x2x2x1xf32> +// CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32> +// CHECK: %from_elements_3 = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64> +// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %from_elements_3 : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32> +// CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]]) +// CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32> +// CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> +// CHECK: return %[[T_18]] : !torch.vtensor<[1,4,15,15],f32> func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { %true = torch.constant.bool true %none = torch.constant.none diff --git a/test/Conversion/TorchToMhlo/lit.local.cfg b/test/Conversion/TorchToStablehlo/lit.local.cfg similarity index 100% rename from test/Conversion/TorchToMhlo/lit.local.cfg rename to test/Conversion/TorchToStablehlo/lit.local.cfg diff --git a/test/Conversion/TorchToMhlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir similarity index 100% rename from test/Conversion/TorchToMhlo/pooling.mlir rename to test/Conversion/TorchToStablehlo/pooling.mlir diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir similarity index 100% rename from test/Conversion/TorchToMhlo/view_like.mlir rename to test/Conversion/TorchToStablehlo/view_like.mlir diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 73b2cbeab6e6..e47370e8af79 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1084,6 +1084,19 @@ func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f3 return %0 : !torch.vtensor<[1,12,128,128],f32> } +// ----- +// CHECK-LABEL: func.func @torch.aten.abs( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[15,15],si64> -> tensor<15x15xi64> +// CHECK: %[[VAL_2:.*]] = "tosa.abs"(%[[VAL_1]]) : (tensor<15x15xi64>) -> tensor<15x15xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<15x15xi64> -> !torch.vtensor<[15,15],si64> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[15,15],si64> +// CHECK: } +func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor<[15,15],si64>{ + %0 = torch.aten.abs %arg0 : !torch.vtensor<[15,15],si64> -> !torch.vtensor<[15,15],si64> + return %0 : !torch.vtensor<[15,15],si64> +} + // ----- // CHECK-LABEL: func.func @torch.aten.where.self( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>, @@ -1100,3 +1113,24 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32> return %0 : !torch.vtensor<[1,12,5,5],f32> } + +// ----- +// CHECK-LABEL: func.func @torch.aten.remainder.Scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor} : () -> tensor +// CHECK: %[[VAL_6:.*]] = "tosa.reciprocal"(%[[VAL_5:.*]]) : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_3:.*]], %[[VAL_6:.*]]) {shift = 0 : i32} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = "tosa.floor"(%[[VAL_7]]) : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_5]], %[[VAL_8]]) {shift = 0 : i32} : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_9]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32> +// CHECK: } +func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32> + return %0 : !torch.vtensor<[2, 4],f32> +} + diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 7c6dac439a2d..b4f9db5df4ef 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1451,6 +1451,38 @@ func.func @torch.aten.to.dtype_layout$same_dtype(%arg0: !torch.tensor<[?,?],f32> return %0 : !torch.tensor<[?,?],f32> } +// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$to_device( +// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { +// CHECK-NEXT: %[[INT6:.*]] = torch.constant.int 6 +// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none +// CHECK-NEXT: %[[CPU:.*]] = torch.constant.device "cpu" +// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.to.device %[[ARG]], %[[CPU]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[?,?],f32>, !torch.Device, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32> +// CHECK-NEXT: return %[[RESULT]] : !torch.tensor<[?,?],f32> +func.func @torch.aten.to.dtype_layout$to_device(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> { + %none = torch.constant.none + %device = torch.constant.device "cpu" + %false = torch.constant.bool false + %int6 = torch.constant.int 6 + %0 = torch.aten.to.dtype_layout %arg0, %int6, %none, %device, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.Device, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32> + return %0 : !torch.tensor<[?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$to_dtype( +// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f16> { +// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none +// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-NEXT: %[[INT5:.*]] = torch.constant.int 5 +// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.to.dtype %[[ARG]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f16> +// CHECK-NEXT: return %[[RESULT]] : !torch.tensor<[?,?],f16> +func.func @torch.aten.to.dtype_layout$to_dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f16> { + %none = torch.constant.none + %false = torch.constant.bool false + %int5 = torch.constant.int 5 + %0 = torch.aten.to.dtype_layout %arg0, %int5, %none, %none, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f16> + return %0 : !torch.tensor<[?,?],f16> +} + // CHECK-LABEL: func.func @torch.aten.ge.float$same_operand( // CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool { // CHECK: %[[TRUE:.*]] = torch.constant.bool true @@ -1929,3 +1961,11 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number %1 = torch.aten.ScalarImplicit %0 : !torch.vtensor<[],si64> -> !torch.number return %1 : !torch.number } + +// CHECK-LABEL: func.func @torch.prims.view_of$fold( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32> +func.func @torch.prims.view_of$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { + %0 = torch.prims.view_of %arg0 : !torch.vtensor<[3,4,2],f32> -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> +} diff --git a/test/Dialect/Torch/decompose-complex-ops-legal.mlir b/test/Dialect/Torch/decompose-complex-ops-legal.mlir index 261ae8c96ba2..9cf4c3e9babd 100644 --- a/test/Dialect/Torch/decompose-complex-ops-legal.mlir +++ b/test/Dialect/Torch/decompose-complex-ops-legal.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=torch.aten.softmax.int" -split-input-file %s | FileCheck %s +// RUN: torch-mlir-opt -torch-decompose-complex-ops="legal-ops=aten.softmax.int" -split-input-file %s | FileCheck %s // CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> { diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index 832f5d592635..178db4fa1da6 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -165,3 +165,8 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list, % func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8> func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8> + +func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,56,96],f16>, %arg1: !torch.vtensor<[1,3,56,96],f16>) -> !torch.list> { + %arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list> + return %arg2 : !torch.list> +} diff --git a/test/Dialect/Torch/reify-dtype-calculations.mlir b/test/Dialect/Torch/reify-dtype-calculations.mlir index 455dfbbfd07d..265497ddf324 100644 --- a/test/Dialect/Torch/reify-dtype-calculations.mlir +++ b/test/Dialect/Torch/reify-dtype-calculations.mlir @@ -39,15 +39,16 @@ func.func @op_with_dtype_promotion(%arg0: !torch.vtensor, %arg1: !torch.vtensor) // ----- -// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten.convolution( +// CHECK-LABEL: func.func private @__torch_mlir_dtype_fn.aten._convolution.deprecated( // CHECK-LABEL: func.func @op_with_optional_tensor_arg$none( // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[OPTIONAL_TUPLE:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional> -// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten.convolution({{.*}}, %[[OPTIONAL_TUPLE]], {{.*}}) : ({{.*}}, !torch.optional>, {{.*}}) -> !torch.int +// CHECK: {{.*}} = func.call @__torch_mlir_dtype_fn.aten._convolution.deprecated({{.*}}, %[[OPTIONAL_TUPLE]], {{.*}}) : ({{.*}}, !torch.optional>, {{.*}}) -> !torch.int func.func @op_with_optional_tensor_arg$none(%input: !torch.vtensor, %weight: !torch.vtensor, %stride: !torch.list, %padding: !torch.list, %dilation: !torch.list, %transposed: !torch.bool, %output_padding: !torch.list, %groups: !torch.int) -> !torch.vtensor { %bias_none = torch.constant.none - %0 = torch.aten.convolution %input, %weight, %bias_none, %stride, %padding, %dilation, %transposed, %output_padding, %groups : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor + %false = torch.constant.bool false + %0 = torch.aten._convolution.deprecated %input, %weight, %bias_none, %stride, %padding, %dilation, %transposed, %output_padding, %groups, %false, %false, %false : !torch.vtensor, !torch.vtensor, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor return %0 : !torch.vtensor } diff --git a/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir index 4bcbae30607a..02d343d92e06 100644 --- a/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir +++ b/test/Dialect/Torch/torch-function-to-torch-backend-pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=torch.aten.square,torch.aten.argmax})' -split-input-file %s | FileCheck %s +// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-function-to-torch-backend-pipeline{backend-legal-ops=aten.square,aten.argmax})' -split-input-file %s | FileCheck %s // CHECK-LABEL: func.func @torch.aten.square func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { diff --git a/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir b/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir new file mode 100644 index 000000000000..9c8a3575494a --- /dev/null +++ b/test/Dialect/Torch/verify-backend-contract-unimplemented-op.mlir @@ -0,0 +1,10 @@ +// RUN: torch-mlir-opt -torch-verify-backend-contract-no-decompositions -split-input-file -verify-diagnostics %s +func.func @forward(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor { + %none = torch.constant.none + %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3,5],f32> to !torch.vtensor<*,f32> + %1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32> + // expected-error @+1 {{unsupported by backend contract: Unimplemented operator 'an.unimplemented.op'}} + %2 = torch.operator "an.unimplemented.op"(%1, %1, %none) : (!torch.tensor<*,f32>, !torch.tensor<*,f32>, !torch.none) -> !torch.tensor + %3 = torch.copy.to_vtensor %2 : !torch.vtensor + return %3 : !torch.vtensor +} diff --git a/test/python/custom_op_shape_dtype_fn.py b/test/python/custom_op_shape_dtype_fn.py new file mode 100644 index 000000000000..d955ec7a2a9a --- /dev/null +++ b/test/python/custom_op_shape_dtype_fn.py @@ -0,0 +1,71 @@ +import os +import tempfile +from typing import List, Tuple + +import torch +import torch.utils.cpp_extension +import torch_mlir +from torch_mlir_e2e_test.annotations import export, annotate_args + + +# RUN: %PYTHON %s | FileCheck %s + + +def identity(x: torch.Tensor): + return x + + +goofy_lib = torch.library.Library("goofy", "DEF") +goofy_lib.define("identity(Tensor t) -> Tensor") +goofy_lib.impl("identity", identity) + +def goofy〇identity〡shape(t: List[int]) -> List[int]: + return t + +def goofy〇identity〡dtype(t_rank_dtype: Tuple[int, int]) -> int: + t_rank, t_dtype = t_rank_dtype + return t_dtype + +def goofy〇identity〡has_value_semantics() -> None: + return + +extra_library = [ + goofy〇identity〡shape, goofy〇identity〡dtype, goofy〇identity〡has_value_semantics] + +class CustomOpExampleModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + b = 2 * a + return torch.ops.goofy.identity(b) + + +mod = CustomOpExampleModule() +mod.eval() + +module = torch_mlir.compile( + mod, + torch.ones(3, 4), + output_type="torch", + backend_legal_ops=["goofy.identity"], + extra_library=extra_library, +) + +print(module) + +# CHECK: module attributes {torch.debug_module_name = "CustomOpExampleModule"} { +# CHECK: func.func @forward(%{{.*}}: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +# CHECK: %{{.*}} = torch.constant.int 2 +# CHECK: %{{.*}} = torch.aten.mul.Scalar %{{.*}}, %{{.*}} : !torch.vtensor<[3,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> +# CHECK: %{{.*}} = torch.operator "goofy.identity"(%{{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK: return %1 : !torch.vtensor<[3,4],f32> +# CHECK: } +# CHECK: } diff --git a/test/python/importer/jit_ir/ivalue_import/tensors.py b/test/python/importer/jit_ir/ivalue_import/tensors.py index 1c0612bf790f..831c619adc58 100644 --- a/test/python/importer/jit_ir/ivalue_import/tensors.py +++ b/test/python/importer/jit_ir/ivalue_import/tensors.py @@ -20,7 +20,10 @@ def __init__(self): self.ones_i64 = torch.ones(1, dtype=torch.int64) self.ones_f32 = torch.ones(1, dtype=torch.float32) self.ones_f64 = torch.ones(1, dtype=torch.float64) - self.ones_bool = torch.ones(1, dtype=torch.bool) + # Because bools turn anything that is non-zero into `True`, it is + # important to check a series of `True`s and `False`s to make sure the + # actual values are being imported rather than just garbage. + self.bool_ = torch.tensor([True, False, True, False, True, False], dtype=torch.bool) self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16) self.ones_f16 = torch.ones(1, dtype=torch.half) self.ones_ui8 = torch.ones(1, dtype=torch.uint8) @@ -35,7 +38,7 @@ def __init__(self): # CHECK: %[[ONES_I64:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi64>) : !torch.tensor<[1],si64> # CHECK: %[[ONES_F32:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor<[1],f32> # CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64> -# CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense : tensor<1xi1>) : !torch.tensor<[1],i1> +# CHECK: %[[BOOL_:.*]] = torch.tensor.literal(dense<[true, false, true, false, true, false]> : tensor<6xi1>) : !torch.tensor<[6],i1> # CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16> # CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16> # CHECK: %[[ONES_UI8:.*]] = torch.tensor.literal(dense<1> : tensor<1xui8>) : !torch.tensor<[1],ui8> @@ -53,7 +56,7 @@ def __init__(self): # CHECK: torch.slot "ones_i64", %[[ONES_I64]] : !torch.tensor<[1],si64> # CHECK: torch.slot "ones_f32", %[[ONES_F32]] : !torch.tensor<[1],f32> # CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64> -# CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1> +# CHECK: torch.slot "bool_", %[[BOOL_]] : !torch.tensor<[6],i1> # CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16> # CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16> # CHECK: torch.slot "ones_ui8", %[[ONES_UI8]] : !torch.tensor<[1],ui8> diff --git a/tools/torch-mlir-opt/torch-mlir-opt.cpp b/tools/torch-mlir-opt/torch-mlir-opt.cpp index 9bf123480f5e..c2e975fed650 100644 --- a/tools/torch-mlir-opt/torch-mlir-opt.cpp +++ b/tools/torch-mlir-opt/torch-mlir-opt.cpp @@ -28,8 +28,7 @@ int main(int argc, char **argv) { #ifdef TORCH_MLIR_ENABLE_STABLEHLO mlir::stablehlo::registerAllDialects(registry); -#endif - return mlir::asMainReturnCode( - mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry, - /*preloadDialectsInContext=*/false)); +#endif + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "MLIR modular optimizer driver\n", registry)); } diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 8e5329074d7a..d272086f24e1 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.15.0.dev20230310 +torchvision==0.16.0.dev20230505 diff --git a/utils/bazel/WORKSPACE.bazel b/utils/bazel/WORKSPACE.bazel index d42dd7e33c0e..374de7d39769 100644 --- a/utils/bazel/WORKSPACE.bazel +++ b/utils/bazel/WORKSPACE.bazel @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") http_archive( name = "bazel_skylib", @@ -113,3 +114,14 @@ http_archive( "https://github.com/bazelbuild/buildtools/archive/refs/tags/4.2.2.tar.gz", ], ) + +maybe( + http_archive, + name = "llvm_zstd", + build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD", + sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0", + strip_prefix = "zstd-1.5.2", + urls = [ + "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz", + ], +) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index ec4d1cb90c59..abfd3ea613a3 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -450,6 +450,7 @@ cc_library( "@llvm-project//mlir:Dialect", "@mlir-hlo//:mlir_hlo", "@mlir-hlo//:transforms_passes", + "@mlir-hlo//stablehlo:register", ], )