From 4fa600ea2995d20c9ec09bbd21e54d666c225adc Mon Sep 17 00:00:00 2001 From: "Jae Hoon (Antonio) Kim" <17433012+antoniojkim@users.noreply.github.com> Date: Tue, 7 Jun 2022 14:38:50 -0400 Subject: [PATCH] E2E HuggingFace Bert using LTC Backend (#912) * Update native function definitions * Add ops to support bert lowering - Add empty_strided and as_strided - Restore zeros_like to op blacklist (Without this, tensors will be unintentionally created with a CPU device rather than lazy) - Check for composite implicit ops and add device data IR - Also fix codegen for functionalization * Add autogen to CMakeList * Remove PyTorch submodule * Reduced BERT model size * Print Mark Step status in Torch MLIR LTC debug string * Apply fixes to work with latest upstream/main - Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode Without this, the tensor type created may have a mismatched type as ImportOptions may cause vtensor to be used instead of tensor * Update shape inference functions - Fixed compute_shape_native_batch_norm when mean and var are uninitialized Previously, the number of shapes returned would be <3 if either mean or val was didn't exist. Instead, we now initialize them with a vector matching the number of channels. - Implemented compute_shape_mul - Fixed bug in reshape shape inference error message * Get MLIR backend more consistent with TS backend - Remove LazyNativeFunctions::_unsafe_view from autogen - Blacklist ops to make JIT graph more like output of TS backend - Print graph when SSA value has mismatch of types and results - Remove normalize_index from LazyShapeInference - Fix seeds for LTC example models * Update and clean up shape inference functions - Prune shape inference functions - Add shape inference function for GenerateSlice - Add shape inference function for GenerateCopy Co-authored-by: Henry Tu --- .gitignore | 1 + .gitmodules | 4 - build_tools/autogen_ltc_backend.py | 120 +++++--- build_tools/autogen_ltc_backend.yaml | 78 ++--- examples/ltc_backend_bert.py | 35 ++- examples/ltc_backend_mnist.py | 2 + externals/pytorch | 1 - python/torch_mlir/csrc/CMakeLists.txt | 14 + .../base_lazy_backend/LazyShapeInference.cpp | 289 ++++-------------- .../base_lazy_backend/LazyShapeInference.h | 110 +++---- .../csrc/base_lazy_backend/backend_impl.cpp | 3 +- .../csrc/base_lazy_backend/ir_builder.h | 3 +- .../mlir_lowering_context.cpp | 6 +- .../mlir_native_functions.cpp | 263 ++++++++++++++++ .../csrc/base_lazy_backend/mlir_node.cpp | 14 +- .../csrc/base_lazy_backend/mlir_node.h | 13 +- .../base_lazy_backend/mlir_node_lowering.cpp | 91 +++++- .../base_lazy_backend/ops/device_data.cpp | 41 +++ .../csrc/base_lazy_backend/ops/device_data.h | 48 +++ .../csrc/base_lazy_backend/ops/to_copy.h | 101 ++++++ .../importer/jit_ir/csrc/node_importer.cpp | 2 +- 21 files changed, 824 insertions(+), 415 deletions(-) delete mode 160000 externals/pytorch create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h create mode 100644 python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h diff --git a/.gitignore b/.gitignore index 29de49a46fe8..74ad3b2e9ec1 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ libtorch* /build/ __pycache__ +*.pyc .pytype diff --git a/.gitmodules b/.gitmodules index fddc722e4d89..62a290ea6000 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,3 @@ [submodule "external/llvm-project"] path = externals/llvm-project url = https://github.com/llvm/llvm-project.git -[submodule "externals/pytorch"] - path = externals/pytorch - url = https://github.com/pytorch/pytorch.git - shallow = true diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 9755790e7673..e1157dd485f2 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -1,9 +1,11 @@ import argparse import hashlib +import importlib import os import subprocess import sys import warnings +from collections import defaultdict from dataclasses import dataclass from pathlib import Path from shutil import which @@ -11,18 +13,16 @@ import yaml -TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve() -TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch") - -sys.path.append(str(TORCH_DIR)) - # PyTorch's LTC backend autogen script +import torchgen import torchgen.dest.lazy_ir import torchgen.gen_lazy_tensor from torchgen.api.lazy import LazyIrSchema from torchgen.gen import get_grouped_native_functions, parse_native_yaml from torchgen.model import NativeFunctionsGroup +TORCH_DIR = Path(importlib.util.find_spec('torch').origin).resolve().parent.parent +TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent def isOptionalCType(arg): return str(type(arg)) == "" @@ -42,20 +42,29 @@ def generate_native_functions( grouped_native_functions = get_grouped_native_functions(native_functions) def get_native_function_name(f): - func = f.func if hasattr(f, "func") else f.functional.func - return str(func.name) + func = f if hasattr(f, "func") else f.functional + return str(func.func.name), func + + def get_opnames(ops): + opnames = defaultdict(set) + for op in ops: + opname = op.split(".")[0] + opnames[opname].add(op) + return opnames - aten_funcs = set(map(get_native_function_name, grouped_native_functions)) + native_functions = dict(map(get_native_function_name, native_functions)) + grouped_native_functions = dict(map(get_native_function_name, grouped_native_functions)) + aten_funcs = get_opnames(set(grouped_native_functions.keys())) with config_path.open() as f: config = yaml.load(f, yaml.CLoader) # List of unsupported ops in LTC autogen because of some error - blacklist = config.get("blacklist", []) + blacklist = set(config.get("blacklist", [])) # List of supported ops that we don't want to do the full codegen for # primarily view ops - supported = config.get("supported", []) + supported = set(config.get("supported", [])) # List of non-native ops to do IR codegen for non_native = config.get("non_native", []) @@ -65,49 +74,54 @@ def get_native_function_name(f): else: cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"] - output = ( - subprocess.check_output( + torch_ops = set( + op[6:] + for op in subprocess.check_output( cmd + [str(torch_ops_file)], encoding="utf-8", ) .strip() .split(os.linesep) ) + torch_opnames = get_opnames(torch_ops) # process ops list - ops = [] - supported_ops = [] - skipped = [] + ops = set() + composite_implicit = set() - for op in output: - op = op[6:] - opname = op.split(".")[0] - - if opname in blacklist or op in blacklist: + for op in torch_ops: + if op not in native_functions: continue - if opname in supported: - supported_ops.append(op) - continue + func = native_functions[op] + base = func.func.name.name.base - if op not in aten_funcs: - skipped.append(op) + if base in blacklist or op in blacklist: + continue + if base in supported or op in supported: continue - ops.append(op) + if func.has_composite_implicit_autograd_kernel and f"{op}_backward" not in torch_ops: + composite_implicit.add(op) + elif func.func.name.name.inplace: + for autogen in func.autogen: + if "functional" in autogen.overload_name: + ops.add(str(autogen)) + else: + ops.add(op) - opnames = sorted(set(ops)) + skipped = set(torch_ops) - ops - supported - composite_implicit # Additional ops to support that are not supported by Torch-MLIR explicitly - supported_ops.extend(config.get("additional_ops", [])) + supported |= set(config.get("additional_ops", [])) with out_file.open("w") as f: yaml.dump( { "backend": "Lazy", "cpp_namespace": "torch::lazy", - "full_codegen": opnames, - "supported": sorted(supported_ops), + "full_codegen": sorted(ops), + "supported": sorted(supported), "non_native": non_native, }, f, @@ -117,10 +131,15 @@ def get_native_function_name(f): dedent( """ + # Composite implicit ops (supported by Torch-MLIR but not differentiable) + {composite_implicit} # Skipped ops (supported by Torch-MLIR but no equivalent native function) + {skipped} """ + ).format( + composite_implicit=os.linesep.join(f"# - {op}" for op in sorted(composite_implicit)), + skipped=os.linesep.join(f"# - {op}" for op in sorted(skipped)), ) - + os.linesep.join(f"# - {op}" for op in sorted(skipped)) ) return parsed_yaml, grouped_native_functions @@ -129,11 +148,13 @@ def get_native_function_name(f): @dataclass(frozen=True) class GenMlirLazyIr(torchgen.dest.GenLazyIR): - def lowering_function(self, schema, declaration_only=True): + def lowering_function(self, schema): signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override" - if declaration_only: + if schema.properties.LowerDeclOnly: return f"{signature};" + elif not schema.properties.Lower: + return "" emplace_arguments = [] for arg in schema.positional_args: @@ -213,7 +234,7 @@ def gen_fallback_code(*args, **kwargs): import re sig_re = re.compile( - r"std::vector\s+(?P\w+)\((?P[^\)]+)\)" + r"std::vector\s+(?P\w+)\((?P[^\)]+)\)" ) global_signatures = {} @@ -307,25 +328,30 @@ def main(args): ) assert backend_path.is_dir() + torchgen_path = Path(torchgen.__path__[0]).resolve() + assert torchgen_path.is_dir() + prev_hash = None hash_file = TORCH_MLIR_DIR.joinpath("generated_backend.hash") if hash_file.exists(): prev_hash = hash_file.read_text().strip() m = hashlib.sha256() - m.update(script_path.read_bytes()) - m.update(config_path.read_bytes()) - m.update(torch_ops_file.read_bytes()) - if native_functions.exists(): - m.update(native_functions.read_bytes()) - - shape_inference_headers = backend_path.joinpath("LazyShapeInference.h") - if shape_inference_headers.exists(): - m.update(shape_inference_headers.read_bytes()) - - shape_inference_defs = backend_path.joinpath("LazyShapeInference.cpp") - if shape_inference_defs.exists(): - m.update(shape_inference_defs.read_bytes()) + + # Add file contents to hash + for path in ( + script_path, + config_path, + torch_ops_file, + native_functions, + backend_path.joinpath("LazyShapeInference.h"), + backend_path.joinpath("LazyShapeInference.cpp"), + torchgen_path.joinpath("dest", "lazy_ir.py"), + torchgen_path.joinpath("api", "lazy.py"), + torchgen_path.joinpath("model.py"), + ): + if path.exists(): + m.update(path.read_bytes()) new_hash = m.hexdigest().strip() diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml index d3f04dab4216..72361a0c060c 100644 --- a/build_tools/autogen_ltc_backend.yaml +++ b/build_tools/autogen_ltc_backend.yaml @@ -1,22 +1,11 @@ blacklist: # List of unsupported ops in LTC autogen because of some error -- arange # Error: Code below assumes there is at least one tensor arg -- contiguous # Error: TODO add support for type BaseType(name=) - empty_like # Error: TODO add support for type BaseType(name=) -- full # Error: Code below assumes there is at least one tensor arg - index.Tensor # Error: TODO not sure if there are other valid types to handle here - index_put # Error: TODO not sure if there are other valid types to handle here - index_put_ # Error: TODO not sure if there are other valid types to handle here - _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here -- ones # Error: Code below assumes there is at least one tensor arg -- ones_like # Error: TODO add support for type BaseType(name=) -- resize_ # Error: TODO add support for type BaseType(name=) - stack # Error: TODO not sure if there are other valid types to handle here -- to.dtype # Error: TODO add support for type BaseType(name=) -- to.other # Error: TODO add support for type BaseType(name=) -- uniform_ # Error: TODO add support for type BaseType(name=) -- zeros # Error: Code below assumes there is at least one tensor arg -- zeros_like # Error: TODO add support for type BaseType(name=) # Additional ops which autogen is supported for but don't compile yet - detach @@ -24,26 +13,34 @@ blacklist: - size - where - copy_ -- _to_copy -- log_softmax # Not inherently differentiable. Needs to be decomposed. -- linear # Not inherently differentiable. Needs to be decomposed. + +# Disabled for consistency with TS backend +- rsub # List of supported ops that we don't want to do the full codegen for # primarily view ops supported: # - bernoulli # - bernoulli_ +- as_strided +- as_strided_ +- _to_copy - cat - clone -- empty +- empty.memory_format +- empty_strided - expand -- fill_ -- native_batch_norm_backward +- fill_.Scalar - permute +- select.int +- slice.Tensor - squeeze +- squeeze.dim - t +- transpose.int - unsqueeze - view +- _unsafe_view additional_ops: # Additional ops to support that are not supported by Torch-MLIR explicitly @@ -53,35 +50,38 @@ additional_ops: # List of non native ops that we only want to do IR node class generation for non_native: - - func: device_data(std::shared_ptr data) -> Tensor - opkind: ltc_device_data - cache_shape: false - - func: scalar(at::Scalar value, at::ScalarType type) -> Tensor + - func: scalar(Scalar value, ScalarType type) -> Tensor opkind: at::prim::Constant - cache_shape: false - - func: expand(Tensor input, std::vector size, bool is_scalar_expand) -> Tensor - - func: view(Tensor input, std::vector output_size) -> Tensor - cache_shape: false - - func: cast(Tensor input, at::ScalarType dtype, optional stype) -> Tensor + properties: + - ShapeCompute + - TreatScalarsAsConstants + - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor + - func: view(Tensor input, int[] output_size) -> Tensor + properties: + - ShapeCompute + - func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor opkind: ltc_cast - cache_shape: false + properties: + - ShapeCompute # View ops only required until proper functionalization pass is introduced into LTC - - func: as_strided_view_update(Tensor target, Tensor input, std::vector size, std::vector stride, int64_t storage_offset) -> Tensor + - func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor opkind: ltc_as_strided_view_update - - func: as_strided(Tensor input, std::vector size, std::vector stride, int64_t storage_offset) -> Tensor - - func: diagonal_view_update(Tensor target, Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor + - func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor + - func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor opkind: ltc_diagonal_view_update - cache_shape: false - - func: diagonal(Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor - - func: narrow_view_update(Tensor input, Tensor source, std::vector base_indices) -> Tensor + properties: + - ShapeCompute + - func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor + - func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor opkind: ltc_narrow_view_update - - func: narrow(Tensor input, std::vector base_indices, std::vector sizes) -> Tensor - - func: permute(Tensor input, std::vector dims) -> Tensor - - func: resize(Tensor input, std::vector size) -> Tensor - - func: select_view_update(Tensor target, Tensor source, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor + - func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor + - func: permute(Tensor input, int[] dims) -> Tensor + - func: resize(Tensor input, int[] size) -> Tensor + - func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor opkind: ltc_select_view_update - cache_shape: false - - func: select(Tensor input, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor + properties: + - ShapeCompute + - func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor - func: squeeze(Tensor input, int dim) -> Tensor - func: unsqueeze(Tensor input, int dim) -> Tensor diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py index 9278b1105088..d8434f5ef14a 100644 --- a/examples/ltc_backend_bert.py +++ b/examples/ltc_backend_bert.py @@ -19,7 +19,7 @@ from datasets.dataset_dict import DatasetDict from torch.utils.data import DataLoader from transformers import BertForSequenceClassification, \ - BertTokenizer, AdamW, get_scheduler + BertConfig, BertTokenizer, AdamW, get_scheduler from typing import List @@ -70,7 +70,7 @@ def train(model: BertForSequenceClassification, return losses -def main(device, lower_only): +def main(device, lower_only, full_size): if device in ("TS", "MLIR_EXAMPLE"): import torch._lazy @@ -95,8 +95,24 @@ def main(device, lower_only): train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8) - model = BertForSequenceClassification.from_pretrained('bert-base-cased', - num_labels=2) + if full_size: + model = BertForSequenceClassification.from_pretrained('bert-base-cased', + num_labels=2) + else: + configuration = BertConfig( + vocab_size=28996, + hidden_size=32, + num_hidden_layers=1, + num_attention_heads=2, + intermediate_size=32, + hidden_act='gelu', + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=512, + layer_norm_eps=1.0e-05, + ) + model = BertForSequenceClassification(configuration) + model.to(device) num_epochs = 3 @@ -115,6 +131,8 @@ def main(device, lower_only): if __name__ == "__main__": + torch.manual_seed(0) + parser = argparse.ArgumentParser() parser.add_argument( "-d", @@ -131,5 +149,12 @@ def main(device, lower_only): default=False, help="Only get backend printout -- do not execute computation", ) + parser.add_argument( + "-f", + "--full_size", + action='store_true', + default=False, + help="Use full sized BERT model instead of one with smaller parameterization", + ) args = parser.parse_args() - main(args.device, args.lower_only) + main(args.device, args.lower_only, args.full_size) diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py index b91c5bb5c643..7448bbc0b349 100644 --- a/examples/ltc_backend_mnist.py +++ b/examples/ltc_backend_mnist.py @@ -73,6 +73,8 @@ def forward(self, x): if __name__ == "__main__": + torch.manual_seed(0) + parser = argparse.ArgumentParser() parser.add_argument( "-d", diff --git a/externals/pytorch b/externals/pytorch deleted file mode 160000 index 9f3d6a00a76c..000000000000 --- a/externals/pytorch +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9f3d6a00a76c567d7c046eabc60ae7a578f7bbde diff --git a/python/torch_mlir/csrc/CMakeLists.txt b/python/torch_mlir/csrc/CMakeLists.txt index 2a20ce93cb98..781a87731750 100644 --- a/python/torch_mlir/csrc/CMakeLists.txt +++ b/python/torch_mlir/csrc/CMakeLists.txt @@ -18,6 +18,17 @@ include_directories(BEFORE ) link_directories("${TORCH_INSTALL_PREFIX}/lib") +# Generate Lazy IR Nodes +execute_process( + COMMAND ${Python3_EXECUTABLE} build_tools/autogen_ltc_backend.py -f + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} +) +add_custom_target( + generate_ltc_sources + ALL + COMMENT "Generating Lazy Tensor Core IR Nodes" + COMMAND ${Python3_EXECUTABLE} build_tools/autogen_ltc_backend.py + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) add_library(torch_mlir_ltc_backend SHARED base_lazy_backend/backend_impl.cpp @@ -29,7 +40,10 @@ add_library(torch_mlir_ltc_backend SHARED base_lazy_backend/mlir_native_functions.cpp base_lazy_backend/mlir_node.cpp base_lazy_backend/mlir_node_lowering.cpp + base_lazy_backend/ops/device_data.cpp + base_lazy_backend/ops/generic.cpp ) +target_compile_features(torch_mlir_ltc_backend PRIVATE cxx_std_17) add_dependencies(torch_mlir_ltc_backend TorchMLIRJITIRImporter diff --git a/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp b/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp index d57fca159e3f..b475d576ed29 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp @@ -7,271 +7,98 @@ // //===----------------------------------------------------------------------===// -#include "LazyShapeInference.h" -#include "../utils/exception.h" +#include +#include #include +#include "../utils/exception.h" +#include "LazyShapeInference.h" + namespace torch { namespace lazy { // TODO(henrytu): Upstream these shape inference functions to PyTorch in the future. -// Turns any negative index positive (assuming it's valid) -int64_t normalize_index(int64_t index, unsigned dims) { - return index < 0 ? (int64_t)dims + index : index; -} - -std::vector -compute_shape_dropout(const at::Tensor& input, double p, bool train) { - return {Shape(input.scalar_type(), input.sizes().vec())}; -} - -std::vector compute_shape_layer_norm( - const at::Tensor& input, at::IntArrayRef normalized_shape, - const c10::optional& weight, - const c10::optional& bias, double eps, bool cudnn_enable) { - return {Shape(input.scalar_type(), input.sizes().vec())}; +std::vector +compute_shape_div(const at::Tensor& self, const at::Scalar & other) { + return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector -compute_shape_matmul(const at::Tensor& self, const at::Tensor& other) { - std::vector sizes; - - auto self_sizes = self.sizes().vec(); - auto other_sizes = other.sizes().vec(); - - // For tensors with dimensions >2, the leading dimensions are for batch info. - // The last 2 (or 1 in the case of a single dim tensor) dimensions are the - // matrix dimensions themselves, which is checked to ensure the matmul op - // is legal. - // - // Example: - // [1, 2, 3, 4] -> [1, 2] batch dims and [3, 4] matrix - // [1, 4, 5] -> [1] batch dims and [4, 5] matrix - // [4, 5] -> [] batch dims and [4, 5] matrix - // [5] -> [] batch dims and [5] matrix - // - // We'll start by splitting the shapes as described above. - auto partition_shape = [](at::ArrayRef sizes) { - if (sizes.size() <= 2) { - return std::make_pair( - std::vector(), - std::vector(sizes.begin(), sizes.end())); - } else { - std::size_t partition_idx = sizes.size() - 2; - return std::make_pair( - std::vector(sizes.begin(), sizes.begin() + partition_idx), - std::vector(sizes.begin() + partition_idx, sizes.end())); - } - }; - auto [self_batch_sizes, self_matrix_sizes] = partition_shape(self_sizes); - auto [other_batch_sizes, other_matrix_sizes] = partition_shape(other_sizes); - - // Insert batch dimensions. - // The final list of sizes will be based on the tensor w/ more dims. - // Individual dimension sizes are "right justified" as we iterate thru - // to pick the larger dimension between them. - // 0 1 1 3 4 - // 5 1 2 - // --------- - // 0 1 5 3 4 <- Result - int64_t self_size, other_size; - std::size_t num_batch_dim = - std::max(self_batch_sizes.size(), other_batch_sizes.size()); - auto get_batch_dim = [&](std::vector batch_sizes, std::size_t dim) { - long idx = dim - num_batch_dim + batch_sizes.size(); - // Negative index means out of bounds, which defaults to a dim size of 1. - return idx < 0 ? 1 : batch_sizes[idx]; - }; - for (std::size_t i = 0; i < num_batch_dim; i++) { - self_size = get_batch_dim(self_batch_sizes, i); - other_size = get_batch_dim(other_batch_sizes, i); - - TORCH_CHECK( - self_size == 1 || other_size == 1 || self_size == other_size, - "At trailing dimension ", i, ", expected for dimensions ", - "to either match or have one of them equal one, but got ", self_size, - " and ", other_size, " instead!"); - - sizes.push_back(std::max(self_size, other_size)); - } - - // Keep track of the inner dimensions of matmul to validate op is valid. - std::pair inner_sizes; - if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 1) { - // Dot-Product -- scalar output, so no dimensions inserted - inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]); - } else if (self_matrix_sizes.size() == 1 && other_matrix_sizes.size() == 2) { - // Vector-Matrix product (m) @ (m, n) -> (n) - inner_sizes = std::make_pair(self_matrix_sizes[0], other_matrix_sizes[0]); - - sizes.push_back(other_matrix_sizes[1]); - } else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 1) { - // Matrix-Vector product (m, n) @ (n) -> (m) - inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]); - - sizes.push_back(self_matrix_sizes[0]); - } else if (self_matrix_sizes.size() == 2 && other_matrix_sizes.size() == 2) { - // Matrix-Matrix product (m, n) @ (n, o) -> (m, o) - inner_sizes = std::make_pair(self_matrix_sizes[1], other_matrix_sizes[0]); - - sizes.push_back(self_matrix_sizes[0]); - sizes.push_back(other_matrix_sizes[1]); - } else { - // By this time, self_matrix_sizes and other_matrix_sizes should have at - // most 2 dims, so if this is executed something has gone wrong... - TORCH_CHECK(false, "Invalid matmul shape combination!"); - } - - TORCH_CHECK( - inner_sizes.first == inner_sizes.second, "Inner dimension of matrix (", - inner_sizes.first, ") does not ", "match (", inner_sizes.second, ")!"); - - return {Shape(self.scalar_type(), sizes)}; +std::vector +compute_shape_mul(const at::Tensor& self, const at::Scalar& other) { + return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_native_batch_norm( +std::vector compute_shape_native_batch_norm( const at::Tensor& input, const c10::optional& weight, const c10::optional& bias, const c10::optional& running_mean, const c10::optional& running_var, bool training, double momentum, double eps) { - std::vector shapes; + std::vector shapes; shapes.reserve(3); shapes.emplace_back(input.scalar_type(), input.sizes().vec()); + + // A separate mean and var needs to be kept for each channel. + TORCH_CHECK( + input.sizes().size() >= 2, + "Input tensor must have at least batch and channel dimensions!"); + int64_t num_features = input.sizes().vec()[1]; + if (running_mean.has_value()) { shapes.emplace_back( running_mean.value().scalar_type(), running_mean.value().sizes().vec()); - if (running_var.has_value()) { - shapes.emplace_back( - running_var.value().scalar_type(), running_var.value().sizes().vec()); - } - } - return shapes; -} - -std::vector -compute_shape_reshape(const at::Tensor& self, at::IntArrayRef shape) { - // Make a copy of the desired output shape. - std::vector sizes(shape.begin(), shape.end()); - - // Product of all sizes in input shape is the number of entries in tensor. - int64_t num_entries = 1; - for (int64_t i : self.sizes().vec()) { - num_entries *= i; - } - - // Validate the number of entries in the desired shape. If there is a wildcard - // dimension, we need to find it now in order to populate it. - long wildcard_idx = -1; - int64_t num_concrete_entries = 1; - for (std::size_t idx = 0; idx < sizes.size(); idx++) { - if (sizes[idx] != -1) { - num_concrete_entries *= sizes[idx]; - } else { - TORCH_CHECK(wildcard_idx == -1, "only one dimension can be inferred"); - wildcard_idx = idx; - } + } else { + shapes.emplace_back( + at::get_default_dtype_as_scalartype(), + std::vector{num_features}); } - if (wildcard_idx == -1) { - // No wildcard, the shape should already be known. - TORCH_CHECK( - num_entries == num_concrete_entries, "shape `[", sizes, - "]` is invalid for input of size ", num_concrete_entries); + if (running_var.has_value()) { + shapes.emplace_back( + running_var.value().scalar_type(), running_var.value().sizes().vec()); } else { - // There is one dimension which is not explicitly declared -- we need to - // infer. - TORCH_CHECK( - num_entries % num_concrete_entries == 0, "shape `[", sizes, - "]` is invalid for input of size ", num_concrete_entries); - - sizes[wildcard_idx] = num_entries / num_concrete_entries; + shapes.emplace_back( + at::get_default_dtype_as_scalartype(), + std::vector{num_features}); } - - return {Shape(self.scalar_type(), sizes)}; -} - -std::vector compute_shape_rsub( - const at::Tensor& self, const at::Scalar& other, const at::Scalar& alpha) { - // Since other is scalar, the result will match tensor shape. - return {Shape(self.scalar_type(), self.sizes().vec())}; + return shapes; } -std::vector -compute_shape_select(const at::Tensor& self, int64_t dim, int64_t index) { - auto original_shape = self.sizes().vec(); - std::vector sizes(original_shape.begin(), original_shape.end()); +std::vector compute_shape_native_batch_norm_backward( + const at::Tensor& grad_out, const at::Tensor& input, + const c10::optional& weight, + const c10::optional& running_mean, + const c10::optional& running_var, + const c10::optional& save_mean, + const c10::optional& save_invstd, bool train, double eps, + ::std::array output_mask) { + std::vector shapes; + shapes.reserve(3); + shapes.emplace_back(input.scalar_type(), input.sizes().vec()); + // A separate mean and var needs to be kept for each channel. TORCH_CHECK( - dim < (int64_t)sizes.size(), "Dimension ", dim, - " is out of bounds for tensor with ", sizes.size(), " dimensions!"); - TORCH_CHECK( - index < sizes[dim], "Index ", index, - " is out of bounds for dimension of size ", sizes[dim]); - sizes.erase(sizes.begin() + dim); - - return {Shape(self.scalar_type(), sizes)}; -} + input.sizes().size() >= 2, + "Input tensor must have at least batch and channel dimensions!"); + int64_t num_features = input.sizes().vec()[1]; + + // `weight` and `bias` are vectors of length C (number of channels)` + shapes.emplace_back( + at::get_default_dtype_as_scalartype(), + std::vector{num_features}); + shapes.emplace_back( + at::get_default_dtype_as_scalartype(), + std::vector{num_features}); -std::vector compute_shape_slice( - const at::Tensor& self, int64_t dim, c10::optional start, - c10::optional end, int64_t step) { - auto original_shape = self.sizes().vec(); - std::vector sizes(original_shape.begin(), original_shape.end()); - - int64_t dim_size = sizes[dim]; - - // Index may be negative, so we must normalize it. - int64_t start_norm = normalize_index(start.value(), dim_size); - int64_t end_norm = normalize_index(end.value(), dim_size); - - if (start_norm >= end_norm || start_norm >= dim_size || end_norm <= 0) { - // Slice is out of bounds, nothing in range. - sizes[dim] = 0; - } else { - // Clamp upper and lower bound to valid indices. - start_norm = std::max((int64_t)0, start_norm); - end_norm = std::min(dim_size, end_norm); - - // Final size is determined by step and interval size. - sizes[dim] = std::ceil((double)(end_norm - start_norm) / (double)step); - } - - return {Shape(self.scalar_type(), sizes)}; + return shapes; } -std::vector compute_shape_softmax( - const at::Tensor& self, int64_t dim, c10::optional dtype) { +std::vector compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { if (dtype.has_value()) { - return {Shape(dtype.value(), self.sizes().vec())}; + return {Shape(*dtype, size)}; } - return {Shape(self.scalar_type(), self.sizes().vec())}; -} - -std::vector -compute_shape_transpose(const at::Tensor& self, int64_t dim0, int64_t dim1) { - auto original_shape = self.sizes().vec(); - std::vector sizes{original_shape.begin(), original_shape.end()}; - - // Index may be negative, so we must normalize it. We create new variables - // instead of replacing the existing ones so that in the case of an error, - // the original values can be printed out. - int64_t dim0_norm = normalize_index(dim0, sizes.size()); - int64_t dim1_norm = normalize_index(dim1, sizes.size()); - - // Verify dimensions are valid. - TORCH_CHECK( - 0 <= dim0_norm && dim0_norm < (int64_t)sizes.size(), "dim0 has value ", - dim0, ", but there are only ", sizes.size(), " tensor dimensions"); - TORCH_CHECK( - 0 <= dim1_norm && dim1_norm < (int64_t)sizes.size(), "dim1 has value ", - dim1, ", but there are only ", sizes.size(), " tensor dimensions"); - - // Swap shapes at dimensions. - std::swap(sizes[dim0_norm], sizes[dim1_norm]); - - return {Shape(self.scalar_type(), sizes)}; + return {Shape(self.scalar_type(), size)}; } } // namespace lazy diff --git a/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h b/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h index 9ac5d6a1704b..956f9fc46e2b 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h +++ b/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h @@ -22,74 +22,48 @@ namespace lazy { // clang-format off -TORCH_API std::vector compute_shape___and__(const at::Tensor & self, const at::Tensor & other); -TORCH_API std::vector compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride); -TORCH_API std::vector compute_shape__shape_as_tensor(const at::Tensor & self); -TORCH_API std::vector compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size); -TORCH_API std::vector compute_shape_abs(const at::Tensor & self); -TORCH_API std::vector compute_shape_adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size); -TORCH_API std::vector compute_shape_add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); -TORCH_API std::vector compute_shape_add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); -TORCH_API std::vector compute_shape_batch_norm(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, const c10::optional & running_mean, const c10::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); -TORCH_API std::vector compute_shape_bernoulli(const at::Tensor & self, c10::optional generator); -TORCH_API std::vector compute_shape_bernoulli_(at::Tensor & self, const at::Tensor & p, c10::optional generator); -TORCH_API std::vector compute_shape_bernoulli_(at::Tensor & self, double p, c10::optional generator); -TORCH_API std::vector compute_shape_bincount(const at::Tensor & self, const c10::optional & weights, int64_t minlength); -TORCH_API std::vector compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size); -TORCH_API std::vector compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right); -TORCH_API std::vector compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value); -TORCH_API std::vector compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups); -TORCH_API std::vector compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups); -TORCH_API std::vector compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups); -TORCH_API std::vector compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups); -TORCH_API std::vector compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups); -TORCH_API std::vector compute_shape_div(const at::Tensor & self, const at::Scalar & other); -TORCH_API std::vector compute_shape_div_(at::Tensor & self, const at::Scalar & other); -TORCH_API std::vector compute_shape_dropout(const at::Tensor & input, double p, bool train); -TORCH_API std::vector compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse); -TORCH_API std::vector compute_shape_expand_as(const at::Tensor & self, const at::Tensor & other); -TORCH_API std::vector compute_shape_flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim); -TORCH_API std::vector compute_shape_floor_divide(const at::Tensor & self, const at::Scalar & other); -TORCH_API std::vector compute_shape_fmod(const at::Tensor & self, const at::Scalar & other); -TORCH_API std::vector compute_shape_full_like(const at::Tensor & self, const at::Scalar & fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); -TORCH_API std::vector compute_shape_hardswish(const at::Tensor & self); -TORCH_API std::vector compute_shape_hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val); -TORCH_API std::vector compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index); -TORCH_API std::vector compute_shape_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional & weight, const c10::optional & bias, double eps, bool cudnn_enable); -TORCH_API std::vector compute_shape_log_softmax(const at::Tensor & self, int64_t dim, c10::optional dtype); -TORCH_API std::vector compute_shape_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); -TORCH_API std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); -TORCH_API std::vector compute_shape_masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); -TORCH_API std::vector compute_shape_masked_select(const at::Tensor & self, const at::Tensor & mask); -TORCH_API std::vector compute_shape_matmul(const at::Tensor & self, const at::Tensor & other); -TORCH_API std::vector compute_shape_max(const at::Tensor & self); -TORCH_API std::vector compute_shape_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); -TORCH_API std::vector compute_shape_mean(const at::Tensor & self, c10::optional dtype); -TORCH_API std::vector compute_shape_mul(const at::Tensor & self, const at::Scalar & other); -TORCH_API std::vector compute_shape_mul_(at::Tensor & self, const at::Scalar & other); -TORCH_API std::vector compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, const c10::optional & running_mean, const c10::optional & running_var, bool training, double momentum, double eps); -TORCH_API std::vector compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional & weight, const c10::optional & bias, double eps); -TORCH_API std::vector compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); -TORCH_API std::vector compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); -TORCH_API std::vector compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); -TORCH_API std::vector compute_shape_rand_like(const at::Tensor & self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); -TORCH_API std::vector compute_shape_relu(const at::Tensor & self); -TORCH_API std::vector compute_shape_relu_(at::Tensor & self); -TORCH_API std::vector compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats); -TORCH_API std::vector compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape); -TORCH_API std::vector compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); -TORCH_API std::vector compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index); -TORCH_API std::vector compute_shape_slice(const at::Tensor & self, int64_t dim, c10::optional start, c10::optional end, int64_t step); -TORCH_API std::vector compute_shape_softmax(const at::Tensor & self, int64_t dim, c10::optional dtype); -TORCH_API std::vector compute_shape_square(const at::Tensor & self); -TORCH_API std::vector compute_shape_std(const at::Tensor & self, bool unbiased); -TORCH_API std::vector compute_shape_sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); -TORCH_API std::vector compute_shape_sub_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); -TORCH_API std::vector compute_shape_sum(const at::Tensor & self, c10::optional dtype); -TORCH_API std::vector compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1); -TORCH_API std::vector compute_shape_type_as(const at::Tensor & self, const at::Tensor & other); -TORCH_API std::vector compute_shape_var(const at::Tensor & self, bool unbiased); -TORCH_API std::vector compute_shape_zero_(at::Tensor & self); +TORCH_API std::vector compute_shape___and__(const at::Tensor & self, const at::Tensor & other); +TORCH_API std::vector compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride); +TORCH_API std::vector compute_shape_abs(const at::Tensor & self); +TORCH_API std::vector compute_shape_add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); +TORCH_API std::vector compute_shape_arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); +TORCH_API std::vector compute_shape_bernoulli(const at::Tensor & self, c10::optional generator); +TORCH_API std::vector compute_shape_bernoulli_functional(const at::Tensor & self, const at::Tensor & p, c10::optional generator); +TORCH_API std::vector compute_shape_bincount(const at::Tensor & self, const c10::optional & weights, int64_t minlength); +TORCH_API std::vector compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right); +TORCH_API std::vector compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value); +TORCH_API std::vector compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups); +TORCH_API std::vector compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups); +TORCH_API std::vector compute_shape_div(const at::Tensor & self, const at::Scalar & other); +TORCH_API std::vector compute_shape_embedding(const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse); +TORCH_API std::vector compute_shape_embedding_dense_backward(const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq); +TORCH_API std::vector compute_shape_flip(const at::Tensor & self, at::IntArrayRef dims); +TORCH_API std::vector compute_shape_fmod(const at::Tensor & self, const at::Scalar & other); +TORCH_API std::vector compute_shape_hardswish(const at::Tensor & self); +TORCH_API std::vector compute_shape_hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val); +TORCH_API std::vector compute_shape_index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index); +TORCH_API std::vector compute_shape_logical_or(const at::Tensor & self, const at::Tensor & other); +TORCH_API std::vector compute_shape_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); +TORCH_API std::vector compute_shape_masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); +TORCH_API std::vector compute_shape_masked_select(const at::Tensor & self, const at::Tensor & mask); +TORCH_API std::vector compute_shape_max(const at::Tensor & self); +TORCH_API std::vector compute_shape_mean(const at::Tensor & self, c10::optional dtype); +TORCH_API std::vector compute_shape_mul(const at::Tensor & self, const at::Scalar & other); +TORCH_API std::vector compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional & weight, const c10::optional & bias, const c10::optional & running_mean, const c10::optional & running_var, bool training, double momentum, double eps); +TORCH_API std::vector compute_shape_native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const c10::optional & weight, const c10::optional & running_mean, const c10::optional & running_var, const c10::optional & save_mean, const c10::optional & save_invstd, bool train, double eps, ::std::array output_mask); +TORCH_API std::vector compute_shape_native_dropout(const at::Tensor & input, double p, c10::optional train); +TORCH_API std::vector compute_shape_native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale); +TORCH_API std::vector compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional & weight, const c10::optional & bias, double eps); +TORCH_API std::vector compute_shape_native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const c10::optional & weight, const c10::optional & bias, ::std::array output_mask); +TORCH_API std::vector compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory); +TORCH_API std::vector compute_shape_relu(const at::Tensor & self); +TORCH_API std::vector compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats); +TORCH_API std::vector compute_shape_resize_functional(const at::Tensor & self, at::IntArrayRef size, c10::optional memory_format); +TORCH_API std::vector compute_shape_sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); +TORCH_API std::vector compute_shape_sum(const at::Tensor & self, c10::optional dtype); +TORCH_API std::vector compute_shape_uniform_functional(const at::Tensor & self, double from, double to, c10::optional generator); +TORCH_API std::vector compute_shape_zero_functional(const at::Tensor & self); + // clang-format on diff --git a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp index e71b2f0e2287..250c76eafda1 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp @@ -20,6 +20,7 @@ #include "backend_impl.h" #include "ir_builder.h" #include "mlir_lowering_context.h" +#include "ops/device_data.h" namespace torch { namespace lazy { @@ -112,7 +113,7 @@ TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const { if (!device_data_node) { return nullptr; } - return device_data_node->data; + return device_data_node->data(); } at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( diff --git a/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h b/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h index 96e8e7050135..0b1410e37091 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h +++ b/python/torch_mlir/csrc/base_lazy_backend/ir_builder.h @@ -20,6 +20,7 @@ #include "dynamic_ir.h" #include "generated/LazyNonNativeIr.h" #include "mlir_node.h" +#include "ops/device_data.h" #include "ops/generic.h" // This file contains the TorchMlir IrBuilder @@ -35,7 +36,7 @@ struct TorchMlirIrBuilder : IrBuilder { NodePtr MakeExpand(const Value& input0, const std::vector& size, const bool& is_scalar_expand) const override { return MakeNode(input0, size, is_scalar_expand); } NodePtr MakeView(const Value& input0, const std::vector& output_size) const override { return MakeNode(input0, output_size); } NodePtr MakeCast(const Value& input0, const at::ScalarType& dtype, const c10::optional& stype = c10::nullopt) const override { return MakeNode(input0, dtype, stype); } - NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode(inputs); } + NodePtr MakeTensorList(const OpList& inputs) const override { return MakeNode(inputs); } NodePtr MakeGeneric(const OpKind& op, const OpList& operands, const Shape& shape, const size_t& num_outputs = 1, const hash_t& hash_seed = static_cast(0x5a2d296e9)) const override { return MakeNode(op, operands, shape, num_outputs, hash_seed); } // view ops diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index 7faa5e98d17a..2ebb963bd13e 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -331,8 +331,12 @@ const std::string TorchMlirComputation::to_string() const { ss << "Input/Output Alias Mapping: \n"; for (InputOutputAlias input_output_alias : input_output_aliases_) { ss << "Output: " << input_output_alias.output_index - << " -> Input param: " << input_output_alias.param_number << std::endl; + << " -> Input param: " << input_output_alias.param_number << "\n"; } + ss << "\n"; + + // Mark Step + ss << "In Mark Step: " << (in_mark_step ? "true" : "false") << "\n"; return ss.str(); } diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp index 6bd80b7178b6..83f6af571e92 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -31,6 +31,7 @@ #include "../utils/sys_utils.h" #include "LazyShapeInference.h" #include "generated/LazyNativeFunctions.h" +#include "ops/to_copy.h" namespace torch { namespace lazy { @@ -166,6 +167,81 @@ torch::lazy::LazyTensorPtr create_view( return input->CreateViewTensor(std::move(view_info)); } +torch::lazy::ViewInfo CreateAsStridedViewInfo( + const torch::lazy::Shape& input_shape, std::vector size, + std::vector stride, c10::optional storage_offset) { + torch::lazy::Shape result_shape = + torch::lazy::Shape(input_shape.scalar_type(), size); + torch::lazy::AsStridedInfo as_strided_info; + as_strided_info.stride = std::move(stride); + if (storage_offset) { + as_strided_info.offset = *storage_offset; + } + return torch::lazy::ViewInfo( + torch::lazy::ViewInfo::Type::kAsStrided, std::move(result_shape), + input_shape, std::move(as_strided_info)); +} + +torch::lazy::LazyTensorPtr lazy_narrow( + const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, + int64_t length) { + auto input_shape = input->shape(); + dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim()); + torch::lazy::Shape narrow_shape = input_shape; + narrow_shape.set_size(dim, length); + + torch::lazy::ViewInfo::Type view_type = + (input_shape.Get().numel() == narrow_shape.numel()) + ? torch::lazy::ViewInfo::Type::kReshape + : torch::lazy::ViewInfo::Type::kNarrow; + torch::lazy::ViewInfo view_info( + view_type, std::move(narrow_shape), input_shape); + view_info.indices[dim] = + torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start); + return input->CreateViewTensor(std::move(view_info)); +} + +torch::lazy::LazyTensorPtr lazy_view( + const torch::lazy::LazyTensorPtr& input, + c10::ArrayRef output_size) { + auto input_shape = input->shape().Get(); + torch::lazy::Shape shape = torch::lazy::Shape( + input_shape.scalar_type(), + at::infer_size(output_size, input_shape.numel())); + torch::lazy::ViewInfo view_info( + torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape); + return input->CreateViewTensor(std::move(view_info)); +} + +torch::lazy::LazyTensorPtr lazy_select( + const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t index) { + auto shape = input->shape(); + dim = torch::lazy::GetCanonicalDimensionIndex(dim, shape.Get().dim()); + torch::lazy::LazyTensorPtr result = lazy_narrow(input, dim, index, 1); + auto new_dims = torch::lazy::DropDimensions(shape.Get().sizes(), {dim}); + return lazy_view(result, new_dims); +} + +torch::lazy::LazyTensorPtr lazy_slice( + const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, + int64_t end, int64_t step) { + auto input_shape = input->shape(); + dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim()); + start = + torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start); + end = torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, end); + // PyTorch allows tensor[-1:0] to return a 0-dim tensor. + if (start > end) { + end = start; + } + step = std::min(step, end - start); + + torch::lazy::SelectInfo select = {dim, start, end, step}; + torch::lazy::ViewInfo view_info( + torch::lazy::ViewInfo::Type::kSelect, input_shape, select); + return input->CreateViewTensor(std::move(view_info)); +} + } // namespace // at::Tensor LazyNativeFunctions::bernoulli( @@ -194,6 +270,44 @@ torch::lazy::LazyTensorPtr create_view( // // return self; // } +at::Tensor LazyNativeFunctions::as_strided( + const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, + c10::optional storage_offset) { + TORCH_LAZY_FN_COUNTER("lazy::"); + torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); + auto xsize = torch::lazy::ToI64Vector(size); + auto xstride = torch::lazy::ToI64Vector(stride); + if (!torch::lazy::StrideIsSupported(xstride)) { + UNIMPLEMENTED_FUNCTION_ERROR(); + } + return torch::lazy::CreateAtenFromLtcTensor( + self_tensor->CreateViewTensor(CreateAsStridedViewInfo( + self_tensor->shape(), std::move(xsize), std::move(xstride), + storage_offset))); +} + +const at::Tensor& LazyNativeFunctions::as_strided_( + const at::Tensor& self, at::IntArrayRef size, at::IntArrayRef stride, + c10::optional storage_offset) { + TORCH_LAZY_FN_COUNTER("lazy::"); + auto self_tensor = torch::lazy::TryGetLtcTensor(self); + auto xsize = torch::lazy::ToI64Vector(size); + auto xstride = torch::lazy::ToI64Vector(stride); + if (!torch::lazy::StrideIsSupported(xstride)) { + UNIMPLEMENTED_FUNCTION_ERROR(); + } + if (self_tensor->data()->view == nullptr) { + self_tensor->SetIrValue(torch::lazy::MakeAsStrided( + self_tensor->GetIrValue(), std::move(xsize), std::move(xstride), + storage_offset.value_or(0))); + } else { + auto input_shape = self_tensor->shape(); + self_tensor->SetSubView(CreateAsStridedViewInfo( + input_shape, std::move(xsize), std::move(xstride), storage_offset)); + } + return self; +} + at::Tensor LazyNativeFunctions::cat(at::TensorList tensors, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto lazy_tensors = torch::lazy::GetLtcTensors(tensors); @@ -298,6 +412,99 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize( return dst; } +at::Tensor LazyNativeFunctions::_to_copy( + const at::Tensor& self, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory, bool non_blocking, + c10::optional memory_format) { + PRINT_FUNCTION(); + auto options = self.options(); + if (dtype) { + // I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)... + // because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it + options = options.dtype(dtype); + } + if (layout) { + options = options.layout(layout); + } + if (memory_format) { + options = options.memory_format(memory_format); + } + if (pin_memory) { + // TODO(whc) can we honor 'pin_memory' in some/all cases? + options = options.pinned_memory(pin_memory); + TORCH_WARN_ONCE("Pinned memory used in lazy _to_copy, check if the " + "behavior is as intended"); + } + + TORCH_LAZY_FN_COUNTER("lazy::"); + auto lazy_self = torch::lazy::TryGetLtcTensor(self); + if (!lazy_self && device && device->type() == c10::kLazy) { + // Case 1: eager->lazy (we create a new lazy tensor) + + auto eager_tensor = + self.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); + lazy_self = torch::lazy::GetOrCreateLtcTensor( + eager_tensor, torch::lazy::atenDeviceToBackendDevice(*device)); + return torch::lazy::CreateAtenFromLtcTensor(lazy_self); + } else if (device && device->type() != c10::kLazy) { + // Case 2: lazy->eager (forces a graph break since we are materializing a tensor) + + TORCH_INTERNAL_ASSERT(lazy_self); + auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); + options = options.device(device); + auto moved_eager_tensor = + eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); + return moved_eager_tensor; + } else if ( + device && device->type() == c10::kLazy && device->has_index() && + device->index() != self.device().index()) { + // Case 3: lazy:0 -> lazy:1 + + // TODO(whc) what do we actually want to do here? + // option 1: materialize, move eager tensor, create new lazy tensor + // - this should be our default, as it is what would happen before we implemented _to_copy + // - actually combines case 1 + case 2 + // option 2: support multiple devices inside one lazy/TS executor (case 4) + // - but: we may have other assumptions that there is just one device per executor? so don't take this lightly + + TORCH_INTERNAL_ASSERT(lazy_self); + auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); + // we move the eager tensor to the 'eager' equivalent of our lazy device + // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use + auto eager_device = c10::Device( + torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index()); + options = options.device(eager_device); + auto moved_eager_tensor = + eager_tensor.to(options, /*non_blocking=*/false, /*copy=*/true); + lazy_self = torch::lazy::GetOrCreateLtcTensor( + moved_eager_tensor, + torch::lazy::atenDeviceToBackendDevice(eager_device)); + return torch::lazy::CreateAtenFromLtcTensor(lazy_self); + + } else { + // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph) + + // Note: captured _to_copy will be executed with real eager tensors, not lazy tensors. + // We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to + // convert an eager tensor back to a lazy one inside the torchscript executor + // lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument + device = c10::nullopt; + + auto shapes = torch::lazy::compute_shape__to_copy( + self, dtype, layout, device, pin_memory, non_blocking, memory_format); + TORCH_INTERNAL_ASSERT(shapes.size() == 1); + auto node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), dtype, layout, device, pin_memory, + non_blocking, memory_format, std::move(shapes)); + + auto result = + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + std::move(node), lazy_self->GetDevice())); + return result; + } +}; + at::Tensor LazyNativeFunctions::empty( at::IntArrayRef size, c10::optional dtype, c10::optional layout, c10::optional device, @@ -313,6 +520,15 @@ at::Tensor LazyNativeFunctions::empty( return CreateLtcTensor(x_result, GetLtcDevice(device)); } +at::Tensor LazyNativeFunctions::empty_strided( + at::IntArrayRef size, at::IntArrayRef stride, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + TORCH_LAZY_FN_COUNTER("lazy::"); + at::Tensor t = empty(size, dtype, layout, device, pin_memory, c10::nullopt); + return LazyNativeFunctions::as_strided(t, size, stride, /*storage_offset=*/0); +} + at::Tensor LazyNativeFunctions::expand( const at::Tensor& self, at::IntArrayRef size, bool implicit) { TORCH_LAZY_FN_COUNTER("lazy::"); @@ -355,6 +571,23 @@ LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) { self_tensor->CreateViewTensor(std::move(view_info))); } +at::Tensor LazyNativeFunctions::select( + const at::Tensor& self, int64_t dim, int64_t index) { + TORCH_LAZY_FN_COUNTER("lazy::"); + return torch::lazy::CreateAtenFromLtcTensor( + lazy_select(torch::lazy::TryGetLtcTensor(self), dim, index)); +} + +at::Tensor LazyNativeFunctions::slice( + const at::Tensor& self, int64_t dim, c10::optional start, + c10::optional end, int64_t step) { + int64_t start_val = start.has_value() ? start.value() : 0; + int64_t end_val = end.has_value() ? end.value() : INT64_MAX; + TORCH_LAZY_FN_COUNTER("lazy::"); + return torch::lazy::CreateAtenFromLtcTensor(lazy_slice( + torch::lazy::TryGetLtcTensor(self), dim, start_val, end_val, step)); +} + at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) { return squeeze(self, -1); } @@ -390,6 +623,21 @@ at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { self_tensor->CreateViewTensor(std::move(view_info))); } +at::Tensor LazyNativeFunctions::transpose( + const at::Tensor& self, int64_t dim0, int64_t dim1) { + TORCH_LAZY_FN_COUNTER("lazy::"); + torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); + + auto input_shape = self_tensor->shape(); + auto permute_dims = torch::lazy::MakeTransposePermutation( + /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim()); + torch::lazy::ViewInfo view_info( + torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims); + + return torch::lazy::CreateAtenFromLtcTensor( + self_tensor->CreateViewTensor(std::move(view_info))); +} + at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -418,6 +666,21 @@ LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) { self_tensor->CreateViewTensor(std::move(view_info))); } +at::Tensor LazyNativeFunctions::_unsafe_view( + const at::Tensor& self, at::IntArrayRef size) { + TORCH_LAZY_FN_COUNTER("lazy::"); + torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); + + auto input_shape = self_tensor->shape().Get(); + torch::lazy::Shape shape = torch::lazy::Shape( + input_shape.scalar_type(), + at::infer_size(torch::lazy::ToI64Vector(size), input_shape.numel())); + torch::lazy::ViewInfo view_info( + torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape); + return torch::lazy::CreateAtenFromLtcTensor( + self_tensor->CreateViewTensor(std::move(view_info))); +} + void InitializeAtenBindings() {} } // namespace lazy diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp index efad5d2e78cd..8e4ad40d261f 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp @@ -76,15 +76,23 @@ TorchMlirOpVector TorchMlirNode::Lower( return {}; } -TensorList::TensorList(OpList values) + +OpKind TorchMlirTensorList::ClassOpKind() { + // Note: this OpKind is separate from ltc_ops.h since it would be a circular + // import otherwise + static const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list"); + return tensor_list_opkind; +} + +TorchMlirTensorList::TorchMlirTensorList(OpList values) : TorchMlirNode( - /*op=*/tensor_list_opkind, + /*op=*/TorchMlirTensorList::ClassOpKind(), /*operands=*/values, /*shapes=*/std::vector(), /*num_outputs=*/1, /*hash_seed=*/kHashSeed) {} -torch::lazy::TorchMlirOpVector TensorList::Lower( +torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower( TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { std::vector tensor_list; CHECK(!operands().empty()); diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h index 7f703942b827..858ab0461f26 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h @@ -42,6 +42,8 @@ class TORCH_API TorchMlirNode : public torch::lazy::Node { TorchMlirNode( OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed); + ~TorchMlirNode() override = default; + hash_t hash() const override; hash_t shapeHash() const override; @@ -58,9 +60,6 @@ class TORCH_API TorchMlirNode : public torch::lazy::Node { hash_t dag_hash_; }; -// Note: this OpKind is separate from ltc_ops.h since it would be a circular -// import otherwise -const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list"); // TensorList represents an at::TensorList which is a vector[Tensor] but is also // a first-class IValue and can be fed as a single input to a TS program. It is @@ -77,9 +76,11 @@ const OpKind tensor_list_opkind = OpKind::Get("lazy_tensors::tensor_list"); // TODO(whc) once Shape() API is moved to Node base, also make it virtual, and // then implement it as NotImplemented for TensorList, also fixing the assertion // that would fail. -struct TORCH_API TensorList : public TorchMlirNode { - TensorList() = delete; - TensorList(OpList values); +struct TORCH_API TorchMlirTensorList : public TorchMlirNode { + static OpKind ClassOpKind(); + + TorchMlirTensorList() = delete; + TorchMlirTensorList(OpList values); torch::lazy::TorchMlirOpVector Lower( TorchMlirFunction function, diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp index dd8a61f7e315..174e6808b070 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp @@ -14,6 +14,7 @@ #include "generated/LazyNonNativeIr.h" #include "mlir_lowering_context.h" #include "mlir_node.h" +#include "ops/device_data.h" #include #include @@ -61,8 +62,8 @@ TorchMlirOpVector LowerTorchMlirBuiltin( for (jit::Value* value : results) { if (value->type()->kind() == c10::TypeKind::TensorType) { TORCH_CHECK( - tensor_type_idx < tensor_types.size(), - "Tensor corresponding to JIT SSA value %", value->debugName(), + tensor_type_idx < tensor_types.size(), function->graph()->toString(), + "\nTensor corresponding to JIT SSA value %", value->debugName(), " corresponds to result #", tensor_type_idx, ", but we only have ", tensor_types.size(), " known types!"); @@ -100,6 +101,79 @@ TorchMlirOpVector LowerTorchMlirBuiltin( function, sym, tensor_types, arguments, kwarguments); } +c10::TensorType& cast_tensor_type(c10::TypePtr value_type) { + auto tensor_type = value_type->cast(); + TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!"); + + return *tensor_type.get(); +} + +c10::optional> +get_tensor_type_shape(c10::TensorType& tensor_type) { + auto& symbolic_shape = tensor_type.symbolic_sizes(); + if (!symbolic_shape.rank()) { + return c10::nullopt; + } + + // Get current tensor shape. + std::vector dims; + dims.resize(*symbolic_shape.rank()); + for (size_t i = 0; i < dims.size(); ++i) { + auto shape_symbol = symbolic_shape[i]; + dims[i] = shape_symbol.is_static() ? shape_symbol.static_size() : -1; + } + + return dims; +} + +std::vector compute_shape_copy(c10::TypePtr value_type) { + c10::TensorType& tensor_type = cast_tensor_type(value_type); + + auto maybe_dims = get_tensor_type_shape(tensor_type); + TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!"); + + auto scalar_type = tensor_type.scalarType(); + TORCH_CHECK( + scalar_type.has_value(), "Unable to copy due to lack of scalar type!"); + return {Shape(scalar_type.value(), maybe_dims.value())}; +} + +std::vector compute_shape_slice( + c10::TypePtr value_type, int64_t dim, int64_t start, int64_t end, + int64_t step) { + c10::TensorType& tensor_type = cast_tensor_type(value_type); + + auto maybe_dims = get_tensor_type_shape(tensor_type); + TORCH_CHECK(maybe_dims.has_value(), "Cannot slice unranked tensor!"); + + std::vector dims = maybe_dims.value(); + int64_t num_dims = dims[dim]; + + // Index may be negative, so we must normalize it. + auto normalize_index = [](int64_t index, unsigned num_dims) { + return index < 0 ? (int64_t)num_dims + index : index; + }; + start = normalize_index(start, num_dims); + end = normalize_index(end, num_dims); + + if (start >= end || start >= num_dims || end <= 0) { + // Slice is out of bounds, nothing in range. + dims[dim] = 0; + } else { + // Clamp upper and lower bound to valid indices. + start = std::max((int64_t)0, start); + end = std::min(num_dims, end); + + // Final size is determined by step and interval size. + dims[dim] = std::ceil((double)(end - start) / (double)step); + } + + auto scalar_type = tensor_type.scalarType(); + TORCH_CHECK( + scalar_type.has_value(), "Unable to slice due to lack of scalar type!"); + return {Shape(scalar_type.value(), dims)}; +} + class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface { public: TorchMlirNodeLowering( @@ -213,14 +287,14 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface { const torch::lazy::DeviceData* device_data_node = torch::lazy::NodeCast( node, *torch::lazy::ltc_device_data); - auto infoptr = device_data_node->data->info(); + auto infoptr = device_data_node->data()->info(); auto deviceDataInfoPtr = (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; if (GRAPH_DUMP_ENABLED) { LOG(ERROR) << "Lowering device data node, tensor id " << deviceDataInfoPtr->tensor_id << std::endl; } - return {loctx()->GetParameter(device_data_node->data)}; + return {loctx()->GetParameter(device_data_node->data())}; } std::vector arguments; @@ -421,8 +495,8 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface { arguments.emplace_back(destination); arguments.emplace_back(source); LowerBuiltin( - at::aten::copy_, c10::ArrayRef({/*shape goes here*/}), - arguments); + at::aten::copy_, + c10::ArrayRef(compute_shape_copy(source->type())), arguments); } torch::jit::Value* GenerateSlice( @@ -434,8 +508,11 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface { arguments.emplace_back(start); arguments.emplace_back(end); arguments.emplace_back(step); + TorchMlirOpVector selected = LowerBuiltin( - at::aten::slice, c10::ArrayRef({/*shape goes here*/}), + at::aten::slice, + c10::ArrayRef( + compute_shape_slice(base->type(), dim, start, end, step)), arguments); CHECK_EQ(selected.size(), 1); return selected.front(); diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp new file mode 100644 index 000000000000..ef7e4d189207 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.cpp @@ -0,0 +1,41 @@ +#include + +#include + +#include "device_data.h" + +namespace torch { +namespace lazy { + +DeviceData::DeviceData(std::shared_ptr data) + : TorchMlirNode( + ClassOpKind(), + data->shape(), + /*num_outputs=*/1, + /*hash_seed=*/static_cast(101)), + data_(std::move(data)) {} + +std::string DeviceData::ToString() const { + std::stringstream ss; + ss << TorchMlirNode::ToString() << ", device=" << data_->device(); + return ss.str(); +} + +const DeviceData* DeviceData::Cast(const Node* node) { + return NodeCast(node); +} + +NodePtr DeviceData::Create(std::shared_ptr data) { + NodePtr node = ReuseOrMakeNode(data); + // ReuseOrMakeNode may return a reused node which has the same shape, + // however, we need to replace the old data_ with the new one. + // Ditching the old data_ is safe because tracing is done iteration + // by iteration, and after we lauch the async device execution for the + // previous iteration, data_ in DeviceData nodes are not needed anymore. + DeviceData* device_data = static_cast(node.get()); + device_data->SetData(data); + return node; +} + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h new file mode 100644 index 000000000000..0e3e8d635054 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/device_data.h @@ -0,0 +1,48 @@ +#pragma once + +#include "../mlir_node.h" + +#include +#include + + +namespace torch { +namespace lazy { + +class TORCH_API DeviceData : public TorchMlirNode { + public: + static OpKind ClassOpKind() { + return ltc_device_data; + } + + explicit DeviceData(std::shared_ptr data); + + // A DeviceData node can be reused if the shape matches, + // but we will substitute the actual data_ pointer under + // the hood. + bool CanBeReused(std::shared_ptr data) const { + return data_->shape() == data->shape(); + } + + std::string ToString() const override; + + const std::shared_ptr& data() const { + return data_; + } + + void SetData(std::shared_ptr data) { + data_ = data; + } + + static const DeviceData* Cast(const Node* node); + + // To reuse IR nodes, use this method to create DeviceData nodes + // instead of calling the constructor directly. + static NodePtr Create(std::shared_ptr data); + + private: + std::shared_ptr data_; +}; + +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h b/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h new file mode 100644 index 000000000000..311d97f90aa6 --- /dev/null +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h @@ -0,0 +1,101 @@ +//===- to_copy.h ----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +// this file is adapted from pytorch/pytorch +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ops/to_copy.h +//===----------------------------------------------------------------------===// + +#pragma once + +#include "../mlir_node.h" + +namespace torch { +namespace lazy { + + +// This IR was copied from code-generated output, but the entire _to_copy operator +// cannot be trivially code genereated since it is only desirable to capture IR for +// certain permutaions of _to_copy (e.g. dtype), and for the others it is difficult to even invoke +// the aten/eager fallback necessitating directly implementing the right to(device) behavior +class ToCopy : public torch::lazy::TorchMlirNode { + public: + ToCopy(const torch::lazy::Value& self, const c10::optional& dtype, const c10::optional& layout, const c10::optional& device, const c10::optional& pin_memory, const bool& non_blocking, const c10::optional& memory_format, std::vector&& shapes) + : torch::lazy::TorchMlirNode(torch::lazy::OpKind(at::aten::_to_copy), + {self}, std::move(shapes), + /* num_outputs */ 1, + torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)), + + dtype(dtype), + layout(layout), + device(device), + pin_memory(pin_memory), + non_blocking(non_blocking), + memory_format(memory_format) {} + + std::string ToString() const override { + std::stringstream ss; + ss << torch::lazy::TorchMlirNode::ToString(); + if (dtype.has_value()) { + ss << ", dtype=" << dtype.value(); + } else { + ss << ", dtype=null"; + } + if (layout.has_value()) { + ss << ", layout=" << layout.value(); + } else { + ss << ", layout=null"; + } + if (device.has_value()) { + ss << ", device=" << device.value(); + } else { + ss << ", device=null"; + } + if (pin_memory.has_value()) { + ss << ", pin_memory=" << pin_memory.value(); + } else { + ss << ", pin_memory=null"; + } + ss << ", non_blocking=" << non_blocking; + if (memory_format.has_value()) { + ss << ", memory_format=" << memory_format.value(); + } else { + ss << ", memory_format=null"; + } + return ss.str(); + } + + torch::lazy::TorchMlirOpVector Lower(TorchMlirFunction function, + torch::lazy::TorchMlirLoweringContext* loctx) const override { + std::vector arguments; + std::vector kwarguments; + arguments.reserve(1); + kwarguments.reserve(6); + size_t i = 0; + arguments.emplace_back(loctx->GetOutputOp(operand(i++))); + kwarguments.emplace_back("dtype", dtype); + kwarguments.emplace_back("layout", layout); + kwarguments.emplace_back("device", device); + kwarguments.emplace_back("pin_memory", pin_memory); + kwarguments.emplace_back("non_blocking", non_blocking); + kwarguments.emplace_back("memory_format", memory_format); + torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments); + CHECK_EQ(_to_copy_out.size(), 1); + + return _to_copy_out; + + } + + c10::optional dtype; + c10::optional layout; + c10::optional device; + c10::optional pin_memory; + bool non_blocking; + c10::optional memory_format; +}; +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index a448760d0862..d4870d04a209 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -131,7 +131,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, auto containedTypes = c10::fmap( node->output()->type()->cast()->containedTypes(), [&](const c10::TypePtr &t) { - MlirType type = getMlirTypeFromTorchType(loc, t); + MlirType type = getMlirTypeFromTorchType(loc, t, importOptions); if (mlirTypeIsNull(type)) { throw mlir_diagnostic_emitted(); }