Fix LTC Decoupling #815

Merged
9 changes: 4 additions & 5 deletions .gitignore
@@ -22,11 +22,10 @@ __pycache__
# Autogenerated LTC files
/generated_native_functions.yaml
/generated_backend.hash
/python/torch_mlir/csrc/base_lazy_backend/LazyIr.h
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/base_lazy_backend/GenLazyShapeInference.cpp
/python/torch_mlir/csrc/base_lazy_backend/RegisterLazy.cpp
/python/torch_mlir/csrc/base_lazy_backend/generated

# LTC Example Backend
/examples/ltc_backend/ltc_backend/_EXAMPLE_MLIR_BACKEND.*.so

# Bazel
bazel-*
4 changes: 4 additions & 0 deletions .gitmodules
@@ -1,3 +1,7 @@
[submodule "external/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
[submodule "externals/pytorch"]
path = externals/pytorch
url = https://github.com/pytorch/pytorch.git
shallow = true
60 changes: 32 additions & 28 deletions build_tools/autogen_ltc_backend.py
@@ -12,32 +12,32 @@
import yaml

TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.parent.joinpath("pytorch")
TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch")

sys.path.append(str(TORCH_DIR.joinpath("tools")))
sys.path.append(str(TORCH_DIR))

# PyTorch's LTC backend autogen script
import codegen.dest.lazy_ir
import codegen.gen_lazy_tensor
from codegen.api.lazy import LazyIrSchema
from codegen.gen import get_grouped_native_functions, parse_native_yaml
from codegen.model import NativeFunctionsGroup
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
from torchgen.api.lazy import LazyIrSchema
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import NativeFunctionsGroup


def isOptionalCType(arg):
return str(type(arg)) == "<class 'tools.codegen.api.types.OptionalCType'>"
return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"


def generate_native_functions(
config_path: Path, torch_ops_file: Path, out_file: Path
):
print("Generating Native Functions Yaml")

native_yaml_path = TORCH_DIR.joinpath(
"aten", "src", "ATen", "native", "native_functions.yaml"
)
native_path = TORCH_DIR.joinpath("aten", "src", "ATen", "native")
native_yaml_path = native_path.joinpath("native_functions.yaml")
tags_yaml_path = native_path.joinpath("tags.yaml")

parsed_yaml = parse_native_yaml(native_yaml_path)
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
native_functions = parsed_yaml.native_functions
grouped_native_functions = get_grouped_native_functions(native_functions)

@@ -57,6 +57,9 @@ def get_native_function_name(f):
# primarily view ops
supported = config.get("supported", [])

# List of non-native ops to do IR codegen for
non_native = config.get("non_native", [])

if which("rg") is not None: # use ripgrep if available as its much faster
cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
else:
@@ -105,6 +108,7 @@ def get_native_function_name(f):
"cpp_namespace": "torch::lazy",
"full_codegen": opnames,
"supported": sorted(supported_ops),
"non_native": non_native,
},
f,
default_flow_style=False,
@@ -123,13 +127,13 @@ def get_native_function_name(f):


@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
class GenMlirLazyIr(torchgen.dest.GenLazyIR):

def lowering_function(self, f):
func = (
f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
)
schema = LazyIrSchema(func)
def lowering_function(self, schema, declaration_only=True):
signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"

if declaration_only:
return f"{signature};"

emplace_arguments = []
for arg in schema.positional_args:
@@ -149,7 +153,7 @@ def lowering_function(self, f):
[f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])

return f"""
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override {{
{signature} {{
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
@@ -159,7 +163,7 @@ def lowering_function(self, f):
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});

return {schema.aten_name}_out;
}}
@@ -178,21 +182,21 @@ def generate_backend(
def gen_fallback_code(*args, **kwargs):
return ""

codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
torchgen.dest.lazy_ir.gen_fallback_code = gen_fallback_code

codegen.gen_lazy_tensor.run_gen_lazy_tensor(
torchgen.gen_lazy_tensor.run_gen_lazy_tensor(
backend_name="TorchMlir",
aten_path=str(TORCH_DIR.joinpath("aten", "src", "ATen")),
source_yaml=str(source_yaml),
output_dir=str(backend_path),
output_dir=str(backend_path.joinpath("generated")),
dry_run=False,
impl_path=str(backend_path.joinpath("mlir_native_functions.cpp")),
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
tensor_class="torch::lazy::LazyTensor",
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
shape_inference_hdr=str(backend_path.joinpath("LazyShapeInference.h")),
lazy_ir_cls=MlirLazyIr,
lazy_ir_generator=GenMlirLazyIr,
)

# Remove lazy_tensor_core imports
@@ -201,7 +205,7 @@ def gen_fallback_code(*args, **kwargs):
"sed",
"-i",
"/lazy_tensor_core/d",
str(backend_path.joinpath("LazyNativeFunctions.cpp")),
str(backend_path.joinpath("generated", "LazyNativeFunctions.cpp")),
]
)

@@ -240,14 +244,14 @@ def extract_signatures(path):
- shape_inference_defs
)
if missing_defs:
backend_path.joinpath("GenLazyShapeInference.cpp").write_text(
backend_path.joinpath("generated", "GenLazyShapeInference.cpp").write_text(
dedent(
"""
// This file contains autogenerated Lazy Shape Inference placeholders
// for ops that dont have a corresponding structured kernel or shape definition

#include "LazyShapeInference.h"
#include "../utils/exception.h"
#include "../LazyShapeInference.h"
#include "../../utils/exception.h"
namespace torch {{
namespace lazy {{
{}
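For reference, here is a rough sketch of what the `lowering_function` template above expands to, for a hypothetical one-tensor op named `my_op`. The op name and the argument-packing line are illustrative assumptions (they correspond to portions of the template not shown in this diff), not the PR's actual generated output.

```cpp
// Illustrative expansion of the codegen template above for a hypothetical
// one-tensor, zero-kwarg op ("my_op"). This is a sketch, not generated code
// from this PR; the arguments.emplace_back(...) line is an assumption about
// the hidden portion of the template.
TorchMlirOpVector Lower(TorchMlirFunction function,
                        TorchMlirLoweringContext* loctx) const override {
  PRINT_FUNCTION();
  std::vector<torch::jit::NamedValue> arguments;
  std::vector<torch::jit::NamedValue> kwarguments;
  arguments.reserve(1);
  kwarguments.reserve(0);
  size_t i = 0;
  arguments.emplace_back(loctx->GetOutputOp(operand(i++)));  // assumed packing

  torch::lazy::TorchMlirOpVector my_op_out = torch::lazy::LowerTorchMlirBuiltin(
      function, op().op, shapes(), arguments, kwarguments);
  CHECK_EQ(my_op_out.size(), 1);

  return my_op_out;
}
```

With the new `declaration_only` flag defaulting to true, the generator can instead emit only the `TorchMlirOpVector Lower(...) const override;` declaration, leaving the full definition to be emitted separately.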
40 changes: 37 additions & 3 deletions build_tools/autogen_ltc_backend.yaml
@@ -38,8 +38,7 @@ supported:
- empty
- expand
- fill_
- native_batch_norm
# - native_batch_norm_backward
- native_batch_norm_backward
- permute
- squeeze
- t
@@ -50,4 +49,39 @@ additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
- _copy_from
- _copy_from_and_resize
- native_batch_norm_backward
# - native_batch_norm_backward

# List of non native ops that we only want to do IR node class generation for
non_native:
- func: device_data(std::shared_ptr<BackendData> data) -> Tensor
opkind: ltc_device_data
cache_shape: false
- func: scalar(at::Scalar value, at::ScalarType type) -> Tensor
opkind: at::prim::Constant
cache_shape: false
- func: expand(Tensor input, std::vector<int64_t> size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, std::vector<int64_t> output_size) -> Tensor
cache_shape: false
- func: cast(Tensor input, at::ScalarType dtype, optional<at::ScalarType> stype) -> Tensor
opkind: ltc_cast
cache_shape: false

# View ops only required until proper functionalization pass is introduced into LTC
- func: as_strided_view_update(Tensor target, Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
opkind: ltc_as_strided_view_update
- func: as_strided(Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
opkind: ltc_diagonal_view_update
cache_shape: false
- func: diagonal(Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, std::vector<int64_t> base_indices) -> Tensor
opkind: ltc_narrow_view_update
- func: narrow(Tensor input, std::vector<int64_t> base_indices, std::vector<int64_t> sizes) -> Tensor
- func: permute(Tensor input, std::vector<int64_t> dims) -> Tensor
- func: resize(Tensor input, std::vector<int64_t> size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
opkind: ltc_select_view_update
cache_shape: false
- func: select(Tensor input, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
- func: squeeze(Tensor input, int dim) -> Tensor
- func: unsqueeze(Tensor input, int dim) -> Tensor
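The `non_native` entries above only drive IR node class generation (no native-function bindings). Purely as a loose sketch, assuming the generated nodes derive from `torch::lazy::TorchMlirNode` (the `node_base` passed to `run_gen_lazy_tensor`), the `scalar` entry might produce a declaration along these lines; the class name, members, and constructor are assumptions for illustration, not the actual generated code.

```cpp
// Loose sketch only (not this PR's generated output): the general shape of an
// IR-node-only class for the "scalar" entry, with opkind at::prim::Constant.
#include "mlir_node.h"  // assumed include; declares torch::lazy::TorchMlirNode

namespace torch {
namespace lazy {

class Scalar : public TorchMlirNode {
 public:
  Scalar(const at::Scalar& value, at::ScalarType type);  // assumed signature

  std::string ToString() const override;

  at::Scalar value;
  at::ScalarType type;
};

} // namespace lazy
} // namespace torch
```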
@@ -12,7 +12,7 @@
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>

#include <torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/utils/debug.h>
1 change: 1 addition & 0 deletions externals/pytorch
Submodule pytorch added at 9f3d6a
6 changes: 3 additions & 3 deletions python/torch_mlir/csrc/CMakeLists.txt
@@ -21,14 +21,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")

add_library(torch_mlir_ltc_backend SHARED
base_lazy_backend/backend_impl.cpp
base_lazy_backend/LazyNativeFunctions.cpp
base_lazy_backend/generated/LazyNativeFunctions.cpp
base_lazy_backend/generated/GenLazyShapeInference.cpp
base_lazy_backend/generated/RegisterLazy.cpp
base_lazy_backend/LazyShapeInference.cpp
base_lazy_backend/GenLazyShapeInference.cpp
base_lazy_backend/mlir_lowering_context.cpp
base_lazy_backend/mlir_native_functions.cpp
base_lazy_backend/mlir_node.cpp
base_lazy_backend/mlir_node_lowering.cpp
base_lazy_backend/RegisterLazy.cpp
)

add_dependencies(torch_mlir_ltc_backend
20 changes: 20 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp
@@ -13,5 +13,25 @@
namespace torch {
namespace lazy {

std::vector<Shape> compute_shape_native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
std::vector<Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
if (running_mean.has_value()) {
shapes.emplace_back(
running_mean.value().scalar_type(), running_mean.value().sizes().vec());
if (running_var.has_value()) {
shapes.emplace_back(
running_var.value().scalar_type(), running_var.value().sizes().vec());
}
}
return shapes;
}

} // namespace lazy
} // namespace torch
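A usage sketch for the `compute_shape_native_batch_norm` helper added above: it returns one `Shape` mirroring the input, plus shapes mirroring `running_mean` and `running_var` when those are provided. The tensor setup below is illustrative and not part of the PR.

```cpp
// Illustrative call into the new shape helper (assumes ATen and the LTC
// headers are available; not part of this PR's diff).
#include <ATen/ATen.h>
#include "LazyShapeInference.h"

void example_native_batch_norm_shapes() {
  at::Tensor input = at::rand({2, 3, 4, 4});  // N, C, H, W
  at::Tensor running_mean = at::zeros({3});
  at::Tensor running_var = at::ones({3});

  std::vector<torch::lazy::Shape> shapes =
      torch::lazy::compute_shape_native_batch_norm(
          input, /*weight=*/c10::nullopt, /*bias=*/c10::nullopt,
          running_mean, running_var,
          /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5);

  // shapes[0] matches the input ({2, 3, 4, 4}); shapes[1] and shapes[2]
  // mirror running_mean and running_var ({3} each), since both are provided.
}
```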
7 changes: 7 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h
@@ -23,6 +23,7 @@ namespace lazy {
// clang-format off

TORCH_API std::vector<Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride);
TORCH_API std::vector<Shape> compute_shape__shape_as_tensor(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_abs(const at::Tensor & self);
@@ -37,6 +38,8 @@ TORCH_API std::vector<Shape> compute_shape_bincount(const at::Tensor & self, con
TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
@@ -63,6 +66,7 @@ TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::op
TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
@@ -82,6 +86,9 @@ TORCH_API std::vector<Shape> compute_shape_sum(const at::Tensor & self, c10::opt
TORCH_API std::vector<Shape> compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);
TORCH_API std::vector<Shape> compute_shape_zero_(at::Tensor & self);

TORCH_API std::vector<Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);

// clang-format on

20 changes: 20 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp
@@ -18,6 +18,7 @@
#include "../utils/debug.h"
#include "../utils/exception.h"
#include "backend_impl.h"
#include "ir_builder.h"
#include "mlir_lowering_context.h"

namespace torch {
@@ -72,6 +73,15 @@ TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
* */
void TorchMlirBackendImpl::PrepareToExit() const {}

/**
* IR Tracing
* */

const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const {
static const IrBuilder* builder = new TorchMlirIrBuilder();
return builder;
}

/**
* Data Transfer
* */
@@ -95,6 +105,16 @@ BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
return std::make_shared<TorchMlirBackendData>(device, shape);
}

BackendDataPtr
TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const {
PRINT_FUNCTION();
auto* device_data_node = dynamic_cast<DeviceData*>(node);
if (!device_data_node) {
return nullptr;
}
return device_data_node->data;
}

at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const {
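The new `GetComputationDataFromNode` override above returns the wrapped `BackendData` only when the node is a `DeviceData` node, and `nullptr` otherwise. A hypothetical caller-side sketch (not from this PR), assuming the usual `torch::lazy::getBackend()` accessor:

```cpp
// Hypothetical caller (not part of this PR): skip lowering when a node already
// carries materialized device data, via the new backend override above.
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/ir.h>

torch::lazy::BackendDataPtr TryGetExistingData(torch::lazy::Node* node) {
  const torch::lazy::BackendImplInterface* backend = torch::lazy::getBackend();
  // nullptr unless `node` is a DeviceData node wrapping real backend data.
  return backend->GetComputationDataFromNode(node);
}
```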
10 changes: 10 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.h
@@ -65,6 +65,12 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
* */
virtual void PrepareToExit() const override;

/**
* IR Tracing
* */

const IrBuilder* GetIrBuilder() const override;

/**
* Configuration
* */
@@ -84,6 +90,10 @@ virtual BackendDataPtr CreateDataPlaceholder(
virtual BackendDataPtr CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape) const override;

// Gets backend data if the node is a device data node. Otherwise returns
// nullptr.
virtual BackendDataPtr GetComputationDataFromNode(Node*) const override;

virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const override;