diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py
index 5b59d32e08a9..c3a7bc44bf34 100644
--- a/build_tools/autogen_ltc_backend.py
+++ b/build_tools/autogen_ltc_backend.py
@@ -158,7 +158,7 @@ def lowering_function(self, f):
         size_t i = 0;
         {emplace_arguments_str}
         {emplace_kwarguments}
-        torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments);
+        torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
         CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});

         return {schema.aten_name}_out;
diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py
index 8e7af23b4fba..b91c5bb5c643 100644
--- a/examples/ltc_backend_mnist.py
+++ b/examples/ltc_backend_mnist.py
@@ -42,7 +42,7 @@ def main(device):
     class Model(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.fc1 = torch.nn.Linear(5, 5)
+            self.fc1 = torch.nn.Linear(5, 10)

         def forward(self, x):
             out = self.fc1(x)
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
index 3ec87e678a1e..ed3274bc719a 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
@@ -36,9 +36,6 @@ TorchMlirComputation::TorchMlirComputation(
     const std::shared_ptr<torch::jit::Graph>& graph)
     : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
       graph_(graph), num_results_(graph_->outputs().size()) {
-
-  // TODO(henrytu): Save parameter shape information.
-
   for (torch::jit::Value* input : graph_->inputs()) {
     parameter_names_.push_back(input->debugName());
   }
@@ -144,22 +141,18 @@ void TorchMlirLoweringContext::AddParameter(
 ComputationPtr TorchMlirLoweringContext::Build() {
   PRINT_FUNCTION();
 
+  // Insert return values into graph.
   for (torch::jit::Value* output : root_tuple_) {
     graph_->block()->registerOutput(output);
   }
 
-  // Create jit::Function from jit::Graph.
-  c10::QualifiedName name("graph");
-  auto cu = std::make_shared<torch::jit::CompilationUnit>();
-  // IMPORTANT: We pass in a COPY of the graph into create_function, since it
-  // may get mutated in the process.
-  auto jit_fn = cu->create_function(std::move(name), std::move(graph_->copy()));
-
   // Generate MLIR.
-  MlirOperation func_op =
-      torch_mlir::importJitFunctionAsFuncOp(mlir_context_, jit_fn);
+  MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp(
+      /*context=*/mlir_context_,
+      /*function=*/generate_jit_fn().get(),
+      /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
+      /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
 
-  // TODO(henrytu): Inject tensor shapes into func_op
   return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
 }
 
@@ -224,6 +217,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
       TORCH_CHECK(
           false, "Unhandled scalar type: ", c10::toString(scalar.type()));
     }
+  } else {
+    // Save parameter shape information.
+    param->setType(torch::jit::TensorType::create(
+        /*scalar_type=*/data->shape().scalar_type(),
+        /*device=*/c10::nullopt,
+        /*sizes=*/c10::VaryingShape<int64_t>(data->shape().sizes()),
+        /*strides=*/c10::VaryingShape<int64_t>(),
+        /*requires_grad=*/c10::nullopt));
   }
 
   it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
@@ -245,6 +246,46 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
   return root_tuple_.size() - 1;
 }
 
+// Sync vector of c10::Argument with type specified from parallel list of
+// jit::Value. There must be a 1:1 map between elements of args and values.
+std::vector<c10::Argument> sync_argument_types(
+    const std::vector<c10::Argument>& args,
+    c10::ArrayRef<torch::jit::Value*> values) {
+  TORCH_CHECK(
+      args.size() == values.size(),
+      "Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ",
+      args.size(), ":", values.size(), " instead!");
+
+  std::vector<c10::Argument> updated_args;
+  for (unsigned i = 0; i < args.size(); i++) {
+    updated_args.push_back(args[i].cloneWithType(values[i]->type()));
+  }
+
+  return updated_args;
+}
+
+std::unique_ptr<torch::jit::Function>
+TorchMlirLoweringContext::generate_jit_fn() const {
+  // IMPORTANT: We pass in a COPY of the graph into create_function, since it
+  // may get mutated in the process.
+  auto fn = std::make_unique<torch::jit::GraphFunction>(
+      c10::QualifiedName("graph"), graph_->copy(), nullptr);
+
+  c10::FunctionSchema schema = fn->getSchema();
+
+  // When constructing the default schema of a jit::GraphFunction, input and
+  // output shapes are stripped (via call to unshapedType(...)); however,
+  // since we want to have shape information in our MLIR, we'll add it back.
+  std::vector<c10::Argument> arguments =
+      sync_argument_types(schema.arguments(), graph_->inputs());
+  std::vector<c10::Argument> returns =
+      sync_argument_types(schema.returns(), graph_->outputs());
+
+  fn->setSchema(schema.cloneWithArguments(arguments).cloneWithReturns(returns));
+
+  return fn;
+}
+
 void TorchMlirLoweringContext::RegisterMlirDialects() {
   // https://reviews.llvm.org/D88162
   mlirRegisterAllDialects(mlir_context_);
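For context, the GetParameter hunk above is what attaches a fixed shape to each graph input. Below is a minimal standalone sketch of the same torch::jit::TensorType::create pattern, using a made-up 2x3 float shape rather than any value from this diff:

#include <torch/csrc/jit/ir/ir.h>

// Build a TensorType with a fixed shape and dtype but unknown strides,
// mirroring the setType(...) call in GetParameter above.
c10::TypePtr makeShapedTensorType() {
  std::vector<int64_t> sizes = {2, 3}; // hypothetical parameter shape
  return torch::jit::TensorType::create(
      /*scalar_type=*/at::kFloat,
      /*device=*/c10::nullopt,
      /*sizes=*/c10::VaryingShape<int64_t>(sizes),
      /*strides=*/c10::VaryingShape<int64_t>(), // strides left unspecified
      /*requires_grad=*/c10::nullopt);
}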
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
index 76c9123cebb3..9d44cc008c4f 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
@@ -122,6 +122,10 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
 
   size_t AddResult(torch::jit::Value* op);
 
+  // Creates a jit::Function from the current jit::Graph. Input and output
+  // type information is patched to include shape.
+  std::unique_ptr<torch::jit::Function> generate_jit_fn() const;
+
   void RegisterMlirDialects();
 
   std::shared_ptr<torch::jit::Graph> graph_;
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
index 12489c1b053e..11c8f1967ef3 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
@@ -44,6 +44,76 @@ namespace torch {
 namespace lazy {
 
+TorchMlirOpVector LowerTorchMlirBuiltin(
+    std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
+    const std::vector<c10::TypePtr> tensor_types,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments) {
+  auto builtin =
+      std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
+  auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
+  auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
+  auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
+  CHECK(sv);
+
+  TorchMlirOpVector results;
+  if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
+    // Op returns multiple values.
+    const auto tuple_call_result = sv->asTuple({}, *function);
+    for (const auto& tuple_component : tuple_call_result) {
+      auto tuple_component_sv =
+          dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
+      results.push_back(tuple_component_sv->getValue());
+    }
+  } else {
+    // Op returns single value.
+    results.push_back(sv->getValue());
+  }
+
+  // Insert known tensor type information.
+  unsigned tensor_type_idx = 0;
+  for (jit::Value* value : results) {
+    if (value->type()->kind() == c10::TypeKind::TensorType) {
+      TORCH_CHECK(
+          tensor_type_idx < tensor_types.size(),
+          "Tensor corresponding to JIT SSA value %", value->debugName(),
+          " corresponds to result #", tensor_type_idx, ", but we only have ",
+          tensor_types.size(), " known types!");
+
+      value->setType(tensor_types[tensor_type_idx++]);
+    }
+  }
+
+  // Ensure that we use up all the known tensor type information available.
+  TORCH_CHECK(
+      tensor_type_idx == tensor_types.size(), tensor_type_idx,
+      " known types were injected into jit::Value, but ", tensor_types.size(),
+      " were provided from lazy::Node!");
+
+  return results;
+}
+
+TorchMlirOpVector LowerTorchMlirBuiltin(
+    std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
+    const c10::ArrayRef<Shape> result_shapes,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments) {
+  std::vector<c10::TypePtr> tensor_types;
+
+  // Generate types with fixed tensor shape information.
+  for (const Shape& shape : result_shapes) {
+    tensor_types.push_back(torch::jit::TensorType::create(
+        /*scalar_type=*/shape.scalar_type(),
+        /*device=*/c10::nullopt,
+        /*sizes=*/c10::VaryingShape<int64_t>(shape.sizes()),
+        /*strides=*/c10::VaryingShape<int64_t>(),
+        /*requires_grad=*/c10::nullopt));
+  }
+
+  return LowerTorchMlirBuiltin(
+      function, sym, tensor_types, arguments, kwarguments);
+}
+
 class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
  public:
   TorchMlirNodeLowering(
@@ -189,12 +259,20 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
       const std::vector<torch::jit::NamedValue>& arguments,
       const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
     return LowerTorchMlirBuiltin(
-        function_, node->op().op, arguments, kwarguments);
+        function_, node->op().op, node->shapes(), arguments, kwarguments);
   }
   TorchMlirOpVector LowerBuiltin(
-      c10::Symbol sym, const std::vector<torch::jit::NamedValue>& arguments,
+      c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
+      const std::vector<torch::jit::NamedValue>& arguments,
       const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
-    return LowerTorchMlirBuiltin(function_, sym, arguments, kwarguments);
+    return LowerTorchMlirBuiltin(
+        function_, sym, result_shapes, arguments, kwarguments);
+  }
+  TorchMlirOpVector LowerBuiltin(
+      c10::Symbol sym, const std::vector<c10::TypePtr> types,
+      const std::vector<torch::jit::NamedValue>& arguments,
+      const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+    return LowerTorchMlirBuiltin(function_, sym, types, arguments, kwarguments);
   }
 
   TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
@@ -222,7 +300,7 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
     dest_arguments.emplace_back(node->stride());
     dest_arguments.emplace_back(node->storage_offset());
     TorchMlirOpVector as_strided_out =
-        LowerBuiltin(at::aten::as_strided, dest_arguments);
+        LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
     CHECK_EQ(as_strided_out.size(), 1);
     torch::jit::Value* as_strided = as_strided_out.front();
     GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
@@ -266,7 +344,7 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
     std::vector<torch::jit::NamedValue> arguments;
     arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
     arguments.emplace_back(node->dtype());
-    return LowerBuiltin(at::aten::to, arguments);
+    return LowerBuiltin(at::aten::to, node->shapes(), arguments);
   }
 
   TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
@@ -383,13 +461,16 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
     std::vector<torch::jit::NamedValue> arguments;
     arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
     arguments.push_back(node->output_size());
-    return LowerBuiltin(at::aten::reshape, arguments);
+    return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
   }
 
   torch::jit::Value* GenerateClone(torch::jit::Value* val) {
     std::vector<torch::jit::NamedValue> clone_arguments;
     clone_arguments.emplace_back(val);
-    TorchMlirOpVector cloned = LowerBuiltin(at::aten::clone, clone_arguments);
+
+    // Type of cloned value should be identical to the original one.
+    TorchMlirOpVector cloned =
+        LowerBuiltin(at::aten::clone, {val->type()}, clone_arguments);
     CHECK_EQ(cloned.size(), 1);
     return cloned.front();
   }
@@ -398,7 +479,9 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
     std::vector<torch::jit::NamedValue> arguments;
     arguments.emplace_back(destination);
     arguments.emplace_back(source);
-    LowerBuiltin(at::aten::copy_, arguments);
+    LowerBuiltin(
+        at::aten::copy_, c10::ArrayRef<Shape>({/*shape goes here*/}),
+        arguments);
   }
 
   torch::jit::Value* GenerateSlice(
@@ -410,7 +493,9 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
     arguments.emplace_back(start);
     arguments.emplace_back(end);
     arguments.emplace_back(step);
-    TorchMlirOpVector selected = LowerBuiltin(at::aten::slice, arguments);
+    TorchMlirOpVector selected = LowerBuiltin(
+        at::aten::slice, c10::ArrayRef<Shape>({/*shape goes here*/}),
+        arguments);
     CHECK_EQ(selected.size(), 1);
     return selected.front();
   }
@@ -424,29 +509,5 @@ TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
       "TorchMlirNodeLowering", static_cast<TorchMlirLoweringContext*>(loctx));
 }
 
-
-TorchMlirOpVector LowerTorchMlirBuiltin(
-    std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
-    const std::vector<torch::jit::NamedValue>& arguments,
-    const std::vector<torch::jit::NamedValue>& kwarguments) {
-  auto builtin =
-      std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
-  auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
-  auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
-  auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
-  CHECK(sv);
-  if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
-    const auto tuple_call_result = sv->asTuple({}, *function);
-    TorchMlirOpVector tuple_result;
-    for (const auto& tuple_component : tuple_call_result) {
-      auto tuple_component_sv =
-          dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
-      tuple_result.push_back(tuple_component_sv->getValue());
-    }
-    return tuple_result;
-  }
-  return {sv->getValue()};
-}
-
 } // namespace lazy
 } // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h
index 2e7774b251fe..f9e028a5cc15 100644
--- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h
@@ -23,6 +23,7 @@ typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
 
 TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
     TorchMlirFunction function, c10::Symbol sym,
+    const c10::ArrayRef<Shape> result_shapes,
     const std::vector<torch::jit::NamedValue>& arguments,
     const std::vector<torch::jit::NamedValue>& kwarguments = {});
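To make the new lowering contract concrete, here is a hypothetical call site for the shape-carrying LowerTorchMlirBuiltin overload declared above; the lowerRelu helper and its arguments are illustrative stand-ins, not part of this change:

// Sketch: lower a one-input lazy node, threading its precomputed output
// shapes into the builtin lowering so the resulting jit::Value gets a
// fully shaped TensorType stamped onto it.
torch::lazy::TorchMlirOpVector lowerRelu(
    torch::lazy::TorchMlirFunction function, const torch::lazy::Node* node,
    torch::jit::Value* input) {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(input);
  // node->shapes() yields one lazy::Shape per tensor result; the overload
  // converts each into a jit::TensorType and calls setType(...) on the
  // matching result before returning it.
  return torch::lazy::LowerTorchMlirBuiltin(
      function, c10::Symbol::fromQualString("aten::relu"), node->shapes(),
      arguments);
}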
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp
index dcda400c7cde..7b2f34bef7f3 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.cpp
@@ -23,12 +23,13 @@ using namespace torch_mlir;
 
 MlirOperation torch_mlir::importJitFunctionAsFuncOp(
     MlirContext context, torch::jit::Function *function,
-    std::function<MlirAttribute(int)> getArgAttribute) {
+    std::function<MlirAttribute(int)> getArgAttribute,
+    const ImportOptions &importOptions) {
   // Useful for debugging:
   // graph->dump();
   MlirLocation loc = mlirLocationUnknownGet(context);
   MlirType functionType =
-      getFunctionTypeFromSchema(context, function->getSchema());
+      getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
   // Use the function's qualified name from the compilation unit.
   // This is a stable linkage name that matches Python module lookup
   // conventions (see compilation unit import in IValueImporter for more details
@@ -70,7 +71,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
   };
   MlirBlock block = importBlock(
       context, torch::jit::toGraphFunction(*function).graph()->block(),
-      createTerminator, inputTypes);
+      createTerminator, inputTypes, importOptions);
   mlirRegionAppendOwnedBlock(bodyRegion, block);
   return func;
 }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h
index 6a5652eb5d2a..3cab1c12c510 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h
@@ -12,6 +12,7 @@
 
 #include <memory>
 
+#include "import_options.h"
 #include "node_importer.h"
 #include "pybind.h"
 
@@ -42,7 +43,8 @@ namespace torch_mlir {
 TORCH_API MlirOperation importJitFunctionAsFuncOp(
     MlirContext context, torch::jit::Function *function,
     std::function<MlirAttribute(int)> getArgAttribute =
-        [](int) -> MlirAttribute { return {nullptr}; });
+        [](int) -> MlirAttribute { return {nullptr}; },
+    const ImportOptions &importOptions = {});
 
 } // namespace torch_mlir
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h
new file mode 100644
index 000000000000..4f24af8851ad
--- /dev/null
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h
@@ -0,0 +1,29 @@
+//===- import_options.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
+#define TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
+
+namespace torch_mlir {
+// Common import options across importers. We define this as a struct to avoid
+// an unstructured proliferation of different kinds of ways to control
+// different parts of the import process.
+struct ImportOptions {
+  // If this is set to true, then all tensors in the program can be assumed to
+  // have value semantics. This can happen, for example, when coming from
+  // LazyTensorCore since conversion to value semantics has already happened at
+  // a higher level there before we see the program. For
+  // calling-convention-impacting decisions, this flag should be interpreted as
+  // a requirement to use a value-semantic tensor type (!torch.vtensor) in
+  // signatures.
+  bool assumeTensorsHaveValueSemantics = false;
+};
+} // namespace torch_mlir
+
+#endif // TORCHMLIRJITIRIMPORTER_CSRC_IMPORT_OPTIONS_H
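As a usage sketch, a caller whose program already has value semantics (as the LTC Build() change earlier in this diff arranges) would request !torch.vtensor signatures like so; mlirContext and jitFunction are stand-in parameters:

// Import a jit::Function while asking the importer to emit value-semantic
// tensor types (!torch.vtensor) in the resulting func op's signature.
MlirOperation importWithValueSemantics(MlirContext mlirContext,
                                       torch::jit::Function *jitFunction) {
  torch_mlir::ImportOptions options;
  options.assumeTensorsHaveValueSemantics = true;

  return torch_mlir::importJitFunctionAsFuncOp(
      /*context=*/mlirContext,
      /*function=*/jitFunction,
      /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
      /*importOptions=*/options);
}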
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp
index a879d46ad940..cfb553444aff 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp
@@ -34,15 +34,18 @@ class NodeImporter {
 public:
   NodeImporter(MlirContext context) : context(context) {}
 
-  void importNode(Node *node, MlirBlock appendToBlock);
+  void importNode(Node *node, MlirBlock appendToBlock,
+                  const ImportOptions &importOptions = {});
   MlirBlock importBlock(
       Block *jitBlock, CreateTerminatorFn createTerminator,
-      c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
+      c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
+      const ImportOptions &importOptions = {});
 
 private:
   MlirBlock createBlockFor(Block *jitBlock,
-                           c10::optional<c10::ArrayRef<MlirType>> blockArgTypes);
+                           c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
+                           const ImportOptions &importOptions = {});
   void mapValue(Value *jitValue, MlirValue value);
   void mapResults(Node *node, MlirOperation operation);
   MlirValue lookupMappedValue(Value *jitValue);
@@ -77,39 +80,39 @@ rearrangeDictConstructInputs(std::vector<MlirValue> &inputs) {
   return rearranged;
 }
 
-void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
+void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
+                              const ImportOptions &importOptions) {
   MlirLocation loc = getMlirLocationFromNode(context, node);
   auto kind = node->kind();
 
   auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
                                      InputsTransformFn t) {
     std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
-    MlirOperation operation =
-        createMlirOperationAtEnd(appendToBlock, opName, loc,
-                                 getMlirTypesFromValues(loc, node->outputs()),
-                                 t ? t(mappedInputs) : mappedInputs);
+    MlirOperation operation = createMlirOperationAtEnd(
+        appendToBlock, opName, loc,
+        getMlirTypesFromValues(loc, node->outputs(), importOptions),
+        t ? t(mappedInputs) : mappedInputs);
     mapResults(node, operation);
   };
 
-  auto createAndMapNodeWithAttribute = [&](Node *node,
-                                           const std::string &opName,
-                                           const std::string &attrName,
-                                           MlirAttribute attr) {
-    MlirOperation operation =
-        createMlirOperationAtEnd(appendToBlock, opName, loc,
-                                 getMlirTypesFromValues(loc, node->outputs()),
-                                 lookupMappedValues(node->inputs()),
-                                 toMlirNamedAttribute(attrName.c_str(), attr));
-    mapResults(node, operation);
-  };
+  auto createAndMapNodeWithAttribute =
+      [&](Node *node, const std::string &opName, const std::string &attrName,
+          MlirAttribute attr) {
+        MlirOperation operation = createMlirOperationAtEnd(
+            appendToBlock, opName, loc,
+            getMlirTypesFromValues(loc, node->outputs(), importOptions),
+            lookupMappedValues(node->inputs()),
+            toMlirNamedAttribute(attrName.c_str(), attr));
+        mapResults(node, operation);
+      };
 
   // Trivial ops with schema.
   auto maybeSchema = node->maybeSchema();
   if (maybeSchema) {
-    MlirOperation operation =
-        createOperationFromSchema(appendToBlock, loc, node->schema(),
-                                  getMlirTypesFromValues(loc, node->outputs()),
-                                  lookupMappedValues(node->inputs()));
+    MlirOperation operation = createOperationFromSchema(
+        appendToBlock, loc, node->schema(),
+        getMlirTypesFromValues(loc, node->outputs(), importOptions),
+        lookupMappedValues(node->inputs()));
     mapResults(node, operation);
     return;
   }
@@ -175,13 +178,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     } else if (output->type()->cast<c10::IntType>()) {
       op = createMlirOperation(
           "torch.constant.int", loc,
-          getMlirTypeFromTorchType(loc, output->type()),
+          getMlirTypeFromTorchType(loc, output->type(), importOptions),
          toMlirNamedAttribute("value",
                                importAttribute(loc, node, c10::attr::value)));
     } else if (output->type()->cast<c10::FloatType>()) {
       op = createMlirOperation(
           "torch.constant.float", loc,
-          getMlirTypeFromTorchType(loc, output->type()),
+          getMlirTypeFromTorchType(loc, output->type(), importOptions),
          toMlirNamedAttribute("value",
                                importAttribute(loc, node, c10::attr::value)));
     } else if (output->type()->cast<c10::StringType>()) {
@@ -199,7 +202,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     } else if (output->type()->cast<c10::DeviceObjType>()) {
       op = createMlirOperation(
           "torch.constant.device", loc,
-          getMlirTypeFromTorchType(loc, output->type()),
+          getMlirTypeFromTorchType(loc, output->type(), importOptions),
          toMlirNamedAttribute(
               "value", mlirStringAttrGet(context, toMlirStringRef(node->s(
                            c10::attr::value)))));
@@ -208,16 +211,15 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
       const std::string &symName = function->qualname().qualifiedName();
       op = createMlirOperation(
           "func.constant", loc,
-          getFunctionTypeFromSchema(context, function->getSchema()),
+          getFunctionTypeFromSchema(context, function->getSchema(),
+                                    importOptions),
          toMlirNamedAttribute(
               "value",
               mlirFlatSymbolRefAttrGet(context, toMlirStringRef(symName))));
     } else if (output->type()->cast<c10::ListType>()) {
       ClassAnnotator dummyAnnotator;
-      MlirValue listValue = importIValue(node->ival(c10::attr::value),
-                                         appendToBlock,
-                                         context,
-                                         dummyAnnotator);
+      MlirValue listValue = importIValue(
+          node->ival(c10::attr::value), appendToBlock, context, dummyAnnotator);
       mapResults(node, mlirOpResultGetOwner(listValue));
       return; // Early return, since `importIValue` already added op to block.
     } else {
@@ -234,7 +236,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
 
   if (kind == c10::prim::Loop) {
     std::vector<MlirType> resultTypes =
-        getMlirTypesFromValues(loc, node->outputs());
+        getMlirTypesFromValues(loc, node->outputs(), importOptions);
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "torch.prim.Loop", loc, resultTypes,
         lookupMappedValues(node->inputs().slice(0, 2)),
@@ -257,13 +259,13 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     };
     mlirRegionAppendOwnedBlock(
         mlirOperationGetRegion(operation, 0),
-        importBlock(node->blocks()[0], createTerminator));
+        importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
     return;
   }
 
   if (kind == c10::prim::If) {
     std::vector<MlirType> resultTypes =
-        getMlirTypesFromValues(loc, node->outputs());
+        getMlirTypesFromValues(loc, node->outputs(), importOptions);
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "torch.prim.If", loc, lookupMappedValue(node->input()),
         resultTypes, mlirRegionCreate(), mlirRegionCreate());
@@ -278,10 +280,10 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     };
     mlirRegionAppendOwnedBlock(
         mlirOperationGetRegion(operation, 0),
-        importBlock(node->blocks()[0], createTerminator));
+        importBlock(node->blocks()[0], createTerminator, c10::nullopt, importOptions));
     mlirRegionAppendOwnedBlock(
         mlirOperationGetRegion(operation, 1),
-        importBlock(node->blocks()[1], createTerminator));
+        importBlock(node->blocks()[1], createTerminator, c10::nullopt, importOptions));
     return;
   }
 
@@ -290,14 +292,14 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     auto methodName = node->s(c10::attr::name);
     torch::jit::Function *function = classType->findMethod(methodName);
     MlirType calleeType =
-        getFunctionTypeFromSchema(context, function->getSchema());
+        getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
     std::vector<MlirType> expectedTypes;
     for (int i = 0, e = mlirFunctionTypeGetNumInputs(calleeType); i < e; ++i) {
       expectedTypes.push_back(mlirFunctionTypeGetInput(calleeType, i));
     }
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "torch.prim.CallMethod", loc,
-        getMlirTypesFromValues(loc, node->outputs()),
+        getMlirTypesFromValues(loc, node->outputs(), importOptions),
         adjustStaticInformationForValues(
             appendToBlock, loc, lookupMappedValues(node->inputs()),
             expectedTypes, /*userAllowsRefinement=*/false),
@@ -312,11 +314,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
     torch::jit::Block *calleeEntryBlock =
         torch::jit::toGraphFunction(*functionType->function()).graph()->block();
     auto expectedTypes = c10::fmap(calleeEntryBlock->inputs(), [&](Value *v) {
-      return getMlirTypeFromTorchType(loc, v->type());
+      return getMlirTypeFromTorchType(loc, v->type(), importOptions);
     });
     MlirOperation operation = createMlirOperationAtEnd(
         appendToBlock, "func.call_indirect", loc,
-        getMlirTypesFromValues(loc, node->outputs()),
+        getMlirTypesFromValues(loc, node->outputs(), importOptions),
         lookupMappedValue(node->input(0)),
         adjustStaticInformationForValues(
             appendToBlock, loc, lookupMappedValues(node->inputs().slice(1)),
@@ -336,10 +338,11 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
 
 MlirBlock NodeImporter::importBlock(
     Block *jitBlock, CreateTerminatorFn createTerminator,
-    c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
-  MlirBlock block = createBlockFor(jitBlock, blockArgTypes);
+    c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
+    const ImportOptions &importOptions) {
+  MlirBlock block = createBlockFor(jitBlock, blockArgTypes, importOptions);
   for (Node *node : jitBlock->nodes()) {
-    importNode(node, block);
+    importNode(node, block, importOptions);
   }
   Node *returnNode = jitBlock->return_node();
   createTerminator(lookupMappedValues(returnNode->inputs()), block);
@@ -347,11 +350,12 @@ MlirBlock NodeImporter::importBlock(
 }
 
 MlirBlock NodeImporter::createBlockFor(
-    Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
+    Block *jitBlock, c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
+    const ImportOptions &importOptions) {
   Node *paramNode = jitBlock->param_node();
   MlirLocation loc = getMlirLocationFromNode(context, paramNode);
   std::vector<MlirType> paramNodeTypes =
-      getMlirTypesFromValues(loc, paramNode->outputs());
+      getMlirTypesFromValues(loc, paramNode->outputs(), importOptions);
   if (!blockArgTypes)
     blockArgTypes = paramNodeTypes;
   else
@@ -402,7 +406,8 @@ NodeImporter::lookupMappedValues(c10::ArrayRef<Value *> values) {
 
 MlirBlock torch_mlir::importBlock(MlirContext context, Block *jitBlock,
                                   CreateTerminatorFn createTerminator,
-                                  c10::optional<c10::ArrayRef<MlirType>> blockArgTypes) {
+                                  c10::optional<c10::ArrayRef<MlirType>> blockArgTypes,
+                                  const ImportOptions &importOptions) {
   NodeImporter importer(context);
-  return importer.importBlock(jitBlock, createTerminator, blockArgTypes);
+  return importer.importBlock(jitBlock, createTerminator, blockArgTypes, importOptions);
 }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h
index c4893fdb96d2..42e491e303f3 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.h
@@ -10,6 +10,8 @@
 #ifndef TORCHMLIRJITIRIMPORTER_CSRC_NODE_IMPORTER_H
 #define TORCHMLIRJITIRIMPORTER_CSRC_NODE_IMPORTER_H
 
+#include "import_options.h"
+
 #include <memory>
 
 #include "pybind.h"
@@ -39,7 +41,8 @@ using CreateTerminatorFn =
 MlirBlock importBlock(
     MlirContext context, torch::jit::Block *jitBlock,
     CreateTerminatorFn createTerminator,
-    c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt);
+    c10::optional<c10::ArrayRef<MlirType>> blockArgTypes = c10::nullopt,
+    const ImportOptions &importOptions = {});
 
 } // namespace torch_mlir
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
index 8a69a73a56d9..49e8e5669986 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp
@@ -115,14 +115,20 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
   throw mlir_diagnostic_emitted();
 }
 
-MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
-                                              const c10::TypePtr &torchType) {
+MlirType
+torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
+                                     const c10::TypePtr &torchType,
+                                     const ImportOptions &importOptions) {
   MlirContext context = mlirLocationGetContext(loc);
   using c10::TypeKind;
   auto kind = torchType->kind();
   switch (kind) {
   case TypeKind::TensorType: {
     auto tensorType = torchType->cast<c10::TensorType>();
+    auto getMlirTensorType = importOptions.assumeTensorsHaveValueSemantics
+                                 ? torchMlirTorchValueTensorTypeGet
+                                 : torchMlirTorchNonValueTensorTypeGet;
+
     // Element type.
     MlirType elementType = {nullptr};
     if (tensorType->scalarType()) {
@@ -135,11 +141,11 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
     auto &sizes = tensorType->symbolic_sizes();
     if (!sizes.rank()) {
       // Unranked.
-      return torchMlirTorchNonValueTensorTypeGet(context,
-                                                 /*numSizes=*/0,
-                                                 /*optionalSizes=*/nullptr,
-                                                 /*optionalDtype=*/
-                                                 elementType);
+      return getMlirTensorType(context,
+                               /*numSizes=*/0,
+                               /*optionalSizes=*/nullptr,
+                               /*optionalDtype=*/
+                               elementType);
     }
     // Ranked with possibly dynamic dims.
     auto &symbolicShape = tensorType->symbolic_sizes();
@@ -149,10 +155,10 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
       auto shapeSymbol = symbolicShape[i];
       dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
     }
-    return torchMlirTorchNonValueTensorTypeGet(context, dims.size(),
-                                               /*optionalSizes=*/dims.data(),
-                                               /*optionalDtype=*/
-                                               elementType);
+    return getMlirTensorType(context, dims.size(),
+                             /*optionalSizes=*/dims.data(),
+                             /*optionalDtype=*/
+                             elementType);
   }
   case TypeKind::IntType: {
     return torchMlirTorchIntTypeGet(context);
@@ -171,13 +177,15 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
   }
   case TypeKind::OptionalType: {
     return torchMlirTorchOptionalTypeGet(getMlirTypeFromTorchType(
-        loc, torchType->cast<c10::OptionalType>()->getElementType()));
+        loc, torchType->cast<c10::OptionalType>()->getElementType(),
+        importOptions));
   }
   case TypeKind::TupleType: {
     std::vector<MlirType> containedTypes;
     for (const c10::TypePtr &type :
          torchType->cast<c10::TupleType>()->containedTypes()) {
-      containedTypes.push_back(getMlirTypeFromTorchType(loc, type));
+      containedTypes.push_back(
+          getMlirTypeFromTorchType(loc, type, importOptions));
     }
     return torchMlirTorchTupleTypeGet(context, containedTypes.size(),
                                       containedTypes.data());
@@ -193,13 +201,14 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
   }
   case TypeKind::ListType: {
     return torchMlirTorchListTypeGet(getMlirTypeFromTorchType(
-        loc, torchType->cast<c10::ListType>()->getElementType()));
+        loc, torchType->cast<c10::ListType>()->getElementType(),
+        importOptions));
   }
   case TypeKind::DictType: {
     auto dictType = torchType->cast<c10::DictType>();
     return torchMlirTorchDictTypeGet(
-        getMlirTypeFromTorchType(loc, dictType->getKeyType()),
-        getMlirTypeFromTorchType(loc, dictType->getValueType()));
+        getMlirTypeFromTorchType(loc, dictType->getKeyType(), importOptions),
+        getMlirTypeFromTorchType(loc, dictType->getValueType(), importOptions));
   }
   case TypeKind::NoneType: {
     return torchMlirTorchNoneTypeGet(context);
@@ -234,10 +243,11 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
 
 MlirType
 torch_mlir::getFunctionTypeFromSchema(MlirContext context,
-                                      const c10::FunctionSchema &schema) {
+                                      const c10::FunctionSchema &schema,
+                                      const ImportOptions &importOptions) {
   MlirLocation loc = mlirLocationUnknownGet(context);
   auto mapType = [&](const c10::TypePtr &torchType) {
-    MlirType type = getMlirTypeFromTorchType(loc, torchType);
+    MlirType type = getMlirTypeFromTorchType(loc, torchType, importOptions);
     if (mlirTypeIsNull(type)) {
       std::stringstream msg;
       msg << "unsupported type in function schema: '"
@@ -370,10 +380,11 @@ MlirLocation torch_mlir::getMlirLocationFromNode(MlirContext context,
 
 std::vector<MlirType>
 torch_mlir::getMlirTypesFromValues(MlirLocation loc,
-                                   c10::ArrayRef<torch::jit::Value *> values) {
+                                   c10::ArrayRef<torch::jit::Value *> values,
+                                   const ImportOptions &importOptions) {
   std::vector<MlirType> ret;
   for (auto value : values) {
-    MlirType t = getMlirTypeFromTorchType(loc, value->type());
+    MlirType t = getMlirTypeFromTorchType(loc, value->type(), importOptions);
     if (mlirTypeIsNull(t))
       throw mlir_diagnostic_emitted("unsupported type");
     ret.push_back(t);
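The ternary in the TensorType case above works because the two C API getters share a prototype. This sketch spells out the equivalent function-pointer selection; the getter signature is assumed from torch-mlir's TorchTypes.h C API header:

// Both getters take (context, numSizes, optionalSizes, optionalDtype) and
// return !torch.vtensor<...> or !torch.tensor<...> respectively.
using TensorTypeGetter = MlirType (*)(MlirContext, intptr_t, const int64_t *,
                                      MlirType);

MlirType makeTensorType(MlirContext ctx, bool valueSemantics,
                        const std::vector<int64_t> &dims, MlirType dtype) {
  TensorTypeGetter get = valueSemantics ? torchMlirTorchValueTensorTypeGet
                                        : torchMlirTorchNonValueTensorTypeGet;
  return get(ctx, dims.size(), dims.data(), dtype);
}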
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h
index 328be1291521..7ca5c2b43b56 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.h
@@ -10,6 +10,8 @@
 #ifndef TORCHMLIRJITIRIMPORTER_CSRC_TORCH_TO_MLIR_UTILS_H
 #define TORCHMLIRJITIRIMPORTER_CSRC_TORCH_TO_MLIR_UTILS_H
 
+#include "import_options.h"
+
 #include <memory>
 
 #include "pybind.h"
@@ -37,14 +39,16 @@ MlirType getMlirTypeForTorchScalarType(MlirLocation loc,
 
 /// Maps a torch type to a corresponding MlirType. Returns a null type
 /// on failure and emits a diagnostic.
 MlirType getMlirTypeFromTorchType(MlirLocation loc,
-                                  const c10::TypePtr &torchType);
+                                  const c10::TypePtr &torchType,
+                                  const ImportOptions &importOptions = {});
 
 /// Creates a FunctionType suitable for expressing the signature of `schema`.
 ///
 /// This can differ from the type inferred from the block of a
 /// torch::jit::Function due to derefinement and refinement of tensor types.
 MlirType getFunctionTypeFromSchema(MlirContext context,
-                                   const c10::FunctionSchema &schema);
+                                   const c10::FunctionSchema &schema,
+                                   const ImportOptions &importOptions = {});
 
 /// Creates an appropriate MlirAttribute that holds the same values as `tensor`.
 MlirAttribute convertTensorToMlirElementsAttr(at::Tensor tensor,
@@ -58,7 +62,8 @@ MlirLocation getMlirLocationFromNode(MlirContext context,
 
 std::vector<MlirType>
 getMlirTypesFromValues(MlirLocation loc,
-                       c10::ArrayRef<torch::jit::Value *> values);
+                       c10::ArrayRef<torch::jit::Value *> values,
+                       const ImportOptions &importOptions = {});
 
 std::vector<MlirValue> adjustStaticInformationForValues(
     MlirBlock appendToBlock, MlirLocation loc, c10::ArrayRef<MlirValue> values,