From 019699c688ffe66b7b453a7b9833fae9dbd25af6 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Wed, 11 May 2022 21:00:06 -0400 Subject: [PATCH] Add static shape for scalar tensors (#833) * Assume zero rank tensors are scalar * Run RefineTypes pass on JIT Graph * Rollback assumption that zero rank tensors are scalar * Set numSizes to -1 for non-ranked tensors * Rename RefineTypes to RefineTupleTypes --- include/torch-mlir-c/TorchTypes.h | 2 ++ lib/CAPI/TorchTypes.cpp | 6 ++++-- .../csrc/base_lazy_backend/mlir_lowering_context.cpp | 5 +++++ .../torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp | 2 +- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index 8cff8da860a9..f459960ee542 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -164,6 +164,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNonValueTensor(MlirType t); /// Gets a !torch.tensor type. /// +/// - `numSizes` having a value of -1 denotes an unranked tensor. /// - `optionalSizes` is allowed to be null, meaning that no size /// information is present (and `numSizes` is ignored in that case). - /// `optionalDtype` is allowed to be null, meaning that no dtype @@ -190,6 +191,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchValueTensor(MlirType t); /// Gets a !torch.vtensor type. /// +/// - `numSizes` having a value of -1 denotes an unranked tensor. /// - `optionalSizes` is allowed to be null, meaning that no size /// information is present (and `numSizes` is ignored in that case). 
/// - `optionalDtype` is allowed to be null, meaning that no dtype diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index 0c67453f3421..6d72e7e1551c 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -199,7 +199,8 @@ MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context, const int64_t *optionalSizes, MlirType optionalDtype) { Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None; - if (optionalSizes) + // if numSizes == -1, then it is unranked. + if (numSizes > -1) optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); return wrap(Torch::NonValueTensorType::get( unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); @@ -231,7 +232,8 @@ MlirType torchMlirTorchValueTensorTypeGet(MlirContext context, const int64_t *optionalSizes, MlirType optionalDtype) { Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None; - if (optionalSizes) + // if numSizes == -1, then it is unranked. + if (numSizes > -1) optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); return wrap(Torch::ValueTensorType::get( unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index f939a7a07c54..7faa5e98d17a 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -13,6 +13,7 @@ #include <iostream> #include <torch/csrc/jit/api/compilation_unit.h> +#include <torch/csrc/jit/passes/refine_tuple_types.h> #include <torch/csrc/lazy/core/lazy_graph_executor.h> #include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h" @@ -108,6 +109,10 @@ void TorchMlirLoweringContext::AddParameter( ComputationPtr TorchMlirLoweringContext::Build() { PRINT_FUNCTION(); + // Since we mutated the types of some nodes to insert shape information, we + // must perform this pass to ensure tuples have up to date output types. + torch::jit::RefineTupleTypes(graph_); + // Insert return values into graph. 
for (torch::jit::Value* output : root_tuple_) { graph_->block()->registerOutput(output); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp index 3cd4ae264fbc..3da29416efcd 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp @@ -144,7 +144,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc, if (!sizes.rank()) { // Unranked. return getMlirTensorType(context, - /*numSizes=*/0, + /*numSizes=*/-1, /*optionalSizes=*/nullptr, /*optionalDtype=*/ elementType);