Fix LTC Decoupling (#815)
* Initial changes

* Fix up native functions

* Further fix decoupling

* Remove unnecessary ops

* Formatting and copyright banners

* Add pytorch submodule
antoniojkim committed Jul 7, 2022
1 parent 7b2597e commit e6bffea
Showing 22 changed files with 733 additions and 217 deletions.
13 changes: 4 additions & 9 deletions .gitignore
@@ -23,15 +23,10 @@ __pycache__
# Bazel
bazel-*

# Libraries
*.so
*.a

# Autogenerated files
/generated_native_functions.yaml
/generated_backend.hash
/python/torch_mlir/csrc/base_lazy_backend/LazyIr.h
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/base_lazy_backend/GenLazyShapeInference.cpp
/python/torch_mlir/csrc/base_lazy_backend/RegisterLazy.cpp
/python/torch_mlir/csrc/base_lazy_backend/generated

# Example backend
examples/ltc_backend/ltc_backend/_EXAMPLE_MLIR_BACKEND.cpython-37m-x86_64-linux-gnu.so
4 changes: 4 additions & 0 deletions .gitmodules
@@ -1,3 +1,7 @@
[submodule "external/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
[submodule "externals/pytorch"]
path = externals/pytorch
url = https://github.com/pytorch/pytorch.git
shallow = true
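
Since the new pytorch submodule is registered as shallow, a fresh checkout needs it initialized before the codegen script can run. A minimal sketch (not part of this commit; the exact build integration may differ):

import subprocess

# Initialize the new shallow pytorch submodule; path and depth mirror the
# .gitmodules entry above.
subprocess.check_call(
    ["git", "submodule", "update", "--init", "--depth", "1", "externals/pytorch"]
)
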
60 changes: 32 additions & 28 deletions build_tools/autogen_ltc_backend.py
@@ -12,32 +12,32 @@
import yaml

TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.parent.joinpath("pytorch")
TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch")

sys.path.append(str(TORCH_DIR.joinpath("tools")))
sys.path.append(str(TORCH_DIR))

# PyTorch's LTC backend autogen script
import codegen.dest.lazy_ir
import codegen.gen_lazy_tensor
from codegen.api.lazy import LazyIrSchema
from codegen.gen import get_grouped_native_functions, parse_native_yaml
from codegen.model import NativeFunctionsGroup
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
from torchgen.api.lazy import LazyIrSchema
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import NativeFunctionsGroup


def isOptionalCType(arg):
return str(type(arg)) == "<class 'tools.codegen.api.types.OptionalCType'>"
return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"


def generate_native_functions(
config_path: Path, torch_ops_file: Path, out_file: Path
):
print("Generating Native Functions Yaml")

native_yaml_path = TORCH_DIR.joinpath(
"aten", "src", "ATen", "native", "native_functions.yaml"
)
native_path = TORCH_DIR.joinpath("aten", "src", "ATen", "native")
native_yaml_path = native_path.joinpath("native_functions.yaml")
tags_yaml_path = native_path.joinpath("tags.yaml")

parsed_yaml = parse_native_yaml(native_yaml_path)
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
native_functions = parsed_yaml.native_functions
grouped_native_functions = get_grouped_native_functions(native_functions)

@@ -57,6 +57,9 @@ def get_native_function_name(f):
# primarily view ops
supported = config.get("supported", [])

# List of non-native ops to do IR codegen for
non_native = config.get("non_native", [])

if which("rg") is not None: # use ripgrep if available as it's much faster
cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
else:
@@ -105,6 +108,7 @@ def get_native_function_name(f):
"cpp_namespace": "torch::lazy",
"full_codegen": opnames,
"supported": sorted(supported_ops),
"non_native": non_native,
},
f,
default_flow_style=False,
@@ -123,13 +127,13 @@ def get_native_function_name(f):


@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
class GenMlirLazyIr(torchgen.dest.GenLazyIR):

def lowering_function(self, f):
func = (
f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
)
schema = LazyIrSchema(func)
def lowering_function(self, schema, declaration_only=True):
signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"

if declaration_only:
return f"{signature};"

emplace_arguments = []
for arg in schema.positional_args:
@@ -149,7 +153,7 @@ def lowering_function(self, f):
[f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])

return f"""
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override {{
{signature} {{
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
@@ -159,7 +163,7 @@ def lowering_function(self, f):
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
return {schema.aten_name}_out;
}}
@@ -178,21 +182,21 @@ def generate_backend(
def gen_fallback_code(*args, **kwargs):
return ""

codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code
torchgen.dest.lazy_ir.gen_fallback_code = gen_fallback_code

codegen.gen_lazy_tensor.run_gen_lazy_tensor(
torchgen.gen_lazy_tensor.run_gen_lazy_tensor(
backend_name="TorchMlir",
aten_path=str(TORCH_DIR.joinpath("aten", "src", "ATen")),
source_yaml=str(source_yaml),
output_dir=str(backend_path),
output_dir=str(backend_path.joinpath("generated")),
dry_run=False,
impl_path=str(backend_path.joinpath("mlir_native_functions.cpp")),
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
tensor_class="torch::lazy::LazyTensor",
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
shape_inference_hdr=str(backend_path.joinpath("LazyShapeInference.h")),
lazy_ir_cls=MlirLazyIr,
lazy_ir_generator=GenMlirLazyIr,
)

# Remove lazy_tensor_core imports
@@ -201,7 +205,7 @@ def gen_fallback_code(*args, **kwargs):
"sed",
"-i",
"/lazy_tensor_core/d",
str(backend_path.joinpath("LazyNativeFunctions.cpp")),
str(backend_path.joinpath("generated", "LazyNativeFunctions.cpp")),
]
)

@@ -240,14 +244,14 @@ def extract_signatures(path):
- shape_inference_defs
)
if missing_defs:
backend_path.joinpath("GenLazyShapeInference.cpp").write_text(
backend_path.joinpath("generated", "GenLazyShapeInference.cpp").write_text(
dedent(
"""
// This file contains autogenerated Lazy Shape Inference placeholders
// for ops that don't have a corresponding structured kernel or shape definition
#include "LazyShapeInference.h"
#include "../utils/exception.h"
#include "../LazyShapeInference.h"
#include "../../utils/exception.h"
namespace torch {{
namespace lazy {{
{}
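
Taken together, the script now runs entirely against the in-tree submodule. A minimal driver sketch for the first stage (the import path and the GeneratedTorchOps.td location are assumptions, not shown in this diff):

from pathlib import Path

# Assumes the torch-mlir checkout root is the working directory and is on
# sys.path so the build_tools script is importable.
from build_tools.autogen_ltc_backend import generate_native_functions

root = Path(".").resolve()
# Writes generated_native_functions.yaml, which generate_backend() then
# feeds into torchgen's run_gen_lazy_tensor().
generate_native_functions(
    config_path=root / "build_tools" / "autogen_ltc_backend.yaml",
    torch_ops_file=root / "include" / "torch-mlir" / "Dialect" / "Torch"
    / "IR" / "GeneratedTorchOps.td",  # assumed location
    out_file=root / "generated_native_functions.yaml",
)
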
40 changes: 37 additions & 3 deletions build_tools/autogen_ltc_backend.yaml
@@ -38,8 +38,7 @@ supported:
- empty
- expand
- fill_
- native_batch_norm
# - native_batch_norm_backward
- native_batch_norm_backward
- permute
- squeeze
- t
@@ -50,4 +49,39 @@ additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
- _copy_from
- _copy_from_and_resize
- native_batch_norm_backward
# - native_batch_norm_backward

# List of non-native ops that we only want to do IR node class generation for
non_native:
- func: device_data(std::shared_ptr<BackendData> data) -> Tensor
opkind: ltc_device_data
cache_shape: false
- func: scalar(at::Scalar value, at::ScalarType type) -> Tensor
opkind: at::prim::Constant
cache_shape: false
- func: expand(Tensor input, std::vector<int64_t> size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, std::vector<int64_t> output_size) -> Tensor
cache_shape: false
- func: cast(Tensor input, at::ScalarType dtype, optional<at::ScalarType> stype) -> Tensor
opkind: ltc_cast
cache_shape: false

# View ops only required until proper functionalization pass is introduced into LTC
- func: as_strided_view_update(Tensor target, Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
opkind: ltc_as_strided_view_update
- func: as_strided(Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
opkind: ltc_diagonal_view_update
cache_shape: false
- func: diagonal(Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, std::vector<int64_t> base_indices) -> Tensor
opkind: ltc_narrow_view_update
- func: narrow(Tensor input, std::vector<int64_t> base_indices, std::vector<int64_t> sizes) -> Tensor
- func: permute(Tensor input, std::vector<int64_t> dims) -> Tensor
- func: resize(Tensor input, std::vector<int64_t> size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
opkind: ltc_select_view_update
cache_shape: false
- func: select(Tensor input, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
- func: squeeze(Tensor input, int dim) -> Tensor
- func: unsqueeze(Tensor input, int dim) -> Tensor
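
The script consumes these sections with plain config.get() lookups (visible in the Python diff above). A short sketch of reading the config and what each non_native entry carries:

import yaml

with open("build_tools/autogen_ltc_backend.yaml") as f:
    config = yaml.safe_load(f)

supported = config.get("supported", [])
non_native = config.get("non_native", [])
for op in non_native:
    # Each entry has a C++-style "func" signature plus optional
    # "opkind"/"cache_shape" hints steering IR node class generation only.
    print(op["func"], op.get("opkind"), op.get("cache_shape"))
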
@@ -12,7 +12,7 @@
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>

#include <torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/generated/LazyNativeFunctions.h>
#include <torch_mlir/csrc/base_lazy_backend/backend_impl.h>
#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/utils/debug.h>
1 change: 1 addition & 0 deletions externals/pytorch
Submodule pytorch added at 9f3d6a
6 changes: 3 additions & 3 deletions python/torch_mlir/csrc/CMakeLists.txt
@@ -21,14 +21,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")

add_library(torch_mlir_ltc_backend SHARED
base_lazy_backend/backend_impl.cpp
base_lazy_backend/LazyNativeFunctions.cpp
base_lazy_backend/generated/LazyNativeFunctions.cpp
base_lazy_backend/generated/GenLazyShapeInference.cpp
base_lazy_backend/generated/RegisterLazy.cpp
base_lazy_backend/LazyShapeInference.cpp
base_lazy_backend/GenLazyShapeInference.cpp
base_lazy_backend/mlir_lowering_context.cpp
base_lazy_backend/mlir_native_functions.cpp
base_lazy_backend/mlir_node.cpp
base_lazy_backend/mlir_node_lowering.cpp
base_lazy_backend/RegisterLazy.cpp
)

add_dependencies(torch_mlir_ltc_backend
20 changes: 20 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp
@@ -13,5 +13,25 @@
namespace torch {
namespace lazy {

std::vector<Shape> compute_shape_native_batch_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
const c10::optional<at::Tensor>& running_mean,
const c10::optional<at::Tensor>& running_var, bool training,
double momentum, double eps) {
std::vector<Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
if (running_mean.has_value()) {
shapes.emplace_back(
running_mean.value().scalar_type(), running_mean.value().sizes().vec());
if (running_var.has_value()) {
shapes.emplace_back(
running_var.value().scalar_type(), running_var.value().sizes().vec());
}
}
return shapes;
}

} // namespace lazy
} // namespace torch
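
For reference, the shape triple this models matches what eager-mode native_batch_norm produces when running stats are present; a quick cross-check in plain PyTorch (assumes a recent release exposing torch.native_batch_norm):

import torch

x = torch.randn(4, 3, 8, 8)
weight = bias = torch.ones(3)
running_mean, running_var = torch.zeros(3), torch.ones(3)
# Training mode: returns (output, save_mean, save_invstd).
out, save_mean, save_invstd = torch.native_batch_norm(
    x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
)
print(out.shape, save_mean.shape, save_invstd.shape)
# torch.Size([4, 3, 8, 8]) torch.Size([3]) torch.Size([3])
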
7 changes: 7 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h
@@ -23,6 +23,7 @@ namespace lazy {
// clang-format off

TORCH_API std::vector<Shape> compute_shape___and__(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape__reshape_alias(const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride);
TORCH_API std::vector<Shape> compute_shape__shape_as_tensor(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape__unsafe_view(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_abs(const at::Tensor & self);
@@ -37,6 +38,8 @@ TORCH_API std::vector<Shape> compute_shape_bincount(const at::Tensor & self, con
TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self, at::IntArrayRef size);
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_convolution(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
@@ -63,6 +66,7 @@ TORCH_API std::vector<Shape> compute_shape_mean(const at::Tensor & self, c10::op
TORCH_API std::vector<Shape> compute_shape_mul(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_mul_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_native_layer_norm(const at::Tensor & input, at::IntArrayRef normalized_shape, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, double eps);
TORCH_API std::vector<Shape> compute_shape_new_empty(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_ones(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_new_zeros(const at::Tensor & self, at::IntArrayRef size, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory);
TORCH_API std::vector<Shape> compute_shape_rand_like(const at::Tensor & self, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory, c10::optional<at::MemoryFormat> memory_format);
@@ -82,6 +86,9 @@ TORCH_API std::vector<Shape> compute_shape_sum(const at::Tensor & self, c10::opt
TORCH_API std::vector<Shape> compute_shape_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1);
TORCH_API std::vector<Shape> compute_shape_type_as(const at::Tensor & self, const at::Tensor & other);
TORCH_API std::vector<Shape> compute_shape_var(const at::Tensor & self, bool unbiased);
TORCH_API std::vector<Shape> compute_shape_zero_(at::Tensor & self);

TORCH_API std::vector<Shape> compute_shape_native_batch_norm(const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps);

// clang-format on

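
The generator compares the declarations in this header against upstream shape-inference definitions to decide which stubs land in generated/GenLazyShapeInference.cpp. A hypothetical sketch of that extraction step (the real extract_signatures() may differ):

import re
from pathlib import Path

header = Path("python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h")
# Collect every compute_shape_* declaration; the generator stubs out any
# that lack a matching upstream definition.
decls = set(re.findall(r"compute_shape_\w+\([^;]*\)", header.read_text()))
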
20 changes: 20 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp
@@ -18,6 +18,7 @@
#include "../utils/debug.h"
#include "../utils/exception.h"
#include "backend_impl.h"
#include "ir_builder.h"
#include "mlir_lowering_context.h"

namespace torch {
@@ -72,6 +73,15 @@ TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
* */
void TorchMlirBackendImpl::PrepareToExit() const {}

/**
* IR Tracing
* */

const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const {
static const IrBuilder* builder = new TorchMlirIrBuilder();
return builder;
}

/**
* Data Transfer
* */
@@ -95,6 +105,16 @@ BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
return std::make_shared<TorchMlirBackendData>(device, shape);
}

BackendDataPtr
TorchMlirBackendImpl::GetComputationDataFromNode(Node* node) const {
PRINT_FUNCTION();
auto* device_data_node = dynamic_cast<DeviceData*>(node);
if (!device_data_node) {
return nullptr;
}
return device_data_node->data;
}

at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const {
10 changes: 10 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/backend_impl.h
@@ -65,6 +65,12 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
* */
virtual void PrepareToExit() const override;

/**
* IR Tracing
* */

const IrBuilder* GetIrBuilder() const override;

/**
* Configuration
* */
@@ -84,6 +90,10 @@ virtual BackendDataPtr CreateDataPlaceholder(
virtual BackendDataPtr CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape) const override;

// Gets backend data if the node is a device data node. Otherwise returns
// nullptr.
virtual BackendDataPtr GetComputationDataFromNode(Node*) const override;

virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data,
c10::optional<at::ScalarType> logical_scalar_type) const override;
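
With GetIrBuilder() and GetComputationDataFromNode() wired up, a registered backend can service ordinary lazy-tensor programs. End-to-end flavor of what this enables from Python (assumes a backend has already been built and registered, e.g. the example MLIR backend; device and module names follow stock PyTorch LTC):

import torch
import torch._lazy

# Ops on "lazy" tensors are traced rather than executed eagerly.
x = torch.randn(2, 3, device="lazy")
y = (x + x).t()
torch._lazy.mark_step()  # flush the traced graph through the backend
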