Got LTC working until compile (#689)
antoniojkim authored Mar 24, 2022
1 parent 5f73f71 commit e173444
Showing 22 changed files with 1,248 additions and 598 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -25,5 +25,5 @@ __pycache__
/python/torch_mlir/csrc/backend/LazyLazyIr.h
/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/backend/LazyShapeInference.cpp
/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
/python/torch_mlir/csrc/backend/RegisterLazy.cpp
203 changes: 114 additions & 89 deletions build_tools/autogen_ltc_backend.py
@@ -3,6 +3,7 @@
import os
import subprocess
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from shutil import which
@@ -21,13 +22,11 @@
from codegen.api.lazy import LazyIrSchema
from codegen.gen import get_grouped_native_functions, parse_native_yaml
from codegen.model import NativeFunctionsGroup
from codegen.gen_backend_stubs import parse_backend_yaml
from codegen.api.types import kernel_signature
from codegen.dest.lazy_ir import ComputeShapeSignature
from codegen.gen_lazy_tensor import parse_full_codegen_ops


def generate_native_functions(aten_ops_file: Path, out_file: Path):
def generate_native_functions(
config_path: Path, torch_ops_file: Path, out_file: Path
):
print("Generating Native Functions Yaml")

native_yaml_path = TORCH_DIR.joinpath(
@@ -44,46 +43,15 @@ def get_native_function_name(f):

aten_funcs = set(map(get_native_function_name, grouped_native_functions))

with config_path.open() as f:
config = yaml.load(f, yaml.CLoader)

# List of unsupported ops in LTC autogen because of some error
blacklist = {
"arange", # Error: Code below assumes there is at least one tensor arg
"bernoulli", # Error: TODO add support for type BaseType(name=<BaseTy.Generator: 1>)
"bernoulli_", # Error: TODO add support for type BaseType(name=<BaseTy.Generator: 1>)
"cat", # Error: TODO not sure if there are other valid types to handle here
"clone", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"contiguous", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"empty_like", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"empty.memory_format", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"index.Tensor", # Error: TODO not sure if there are other valid types to handle here
"index_put", # Error: TODO not sure if there are other valid types to handle here
"index_put_", # Error: TODO not sure if there are other valid types to handle here
"ones", # Error: Code below assumes there is at least one tensor arg
"ones_like", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"resize_", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"stack", # Error: TODO not sure if there are other valid types to handle here
"to.dtype", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"to.other", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"uniform_", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
"zeros", # Error: Code below assumes there is at least one tensor arg
"zeros_like", # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
}

# Additional ops which autogen is supported for but don't compile yet
blacklist |= {"item", "size", "where"}
blacklist = config.get("blacklist", [])

# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported = {
"expand",
# "native_batch_norm_backward",
"native_batch_norm",
"permute",
"repeat",
"squeeze",
"t",
"unsqueeze",
"view",
}
supported = config.get("supported", [])

if which("rg") is not None: # use ripgrep if available as its much faster
cmd = ["rg", "-o", "-N", r"aten::[0-9a-zA-Z_\.]+"]
@@ -92,7 +60,7 @@ def get_native_function_name(f):

output = (
subprocess.check_output(
cmd + [str(aten_ops_file)],
cmd + [str(torch_ops_file)],
encoding="utf-8",
)
.strip()
@@ -123,6 +91,9 @@

opnames = sorted(set(ops))

# Additional ops to support that are not supported by Torch-MLIR explicitly
supported_ops.extend(config.get("additional_ops", []))

with out_file.open("w") as f:
yaml.dump(
{
@@ -167,7 +138,10 @@ def lowering_body(self, f):


def generate_backend(
source_yaml: Path, backend_path: Path, parsed_yaml: dict, grouped_native_functions: list
source_yaml: Path,
backend_path: Path,
parsed_yaml: dict,
grouped_native_functions: list,
):
print("Running Lazy Tensor Autogen")

@@ -178,6 +152,7 @@ def gen_fallback_code(*args, **kwargs):
codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code

codegen.gen_lazy_tensor.run(
backend_name="TorchMlir",
source_yaml=str(source_yaml),
output_dir=str(backend_path),
dry_run=False,
@@ -201,63 +176,104 @@ def gen_fallback_code(*args, **kwargs):
]
)

# Autogenerate shape inference placeholders
# programmatically check shape inference declarations
import re

sig_re = re.compile(f"std::vector<Shape> (?P<name>[_a-zA-Z0-9]+)\((?P<signature>.+)\);")
upstream_shape_inference_decls = set(
(name, signature)
for name, signature in sig_re.findall(
TORCH_DIR.joinpath("torch", "csrc", "lazy", "core", "shape_inference.h").read_text()
)
sig_re = re.compile(
r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
)
shape_inference_decls = backend_path.joinpath("LazyShapeInference.h").read_text()

shape_inference_defs = []
for name, signature in sig_re.findall(shape_inference_decls):
if (name, signature) in upstream_shape_inference_decls:
continue
global_signatures = {}

def extract_signatures(path):
signatures = set()
for name, args in sig_re.findall(path.read_text()):
signature = re.sub(r"\s+", "", f"{name}({args})")
global_signatures[signature] = (name, args)
signatures.add(signature)
return signatures

upstream_shape_inference_decls = extract_signatures(
TORCH_DIR.joinpath("torch", "csrc", "lazy", "core", "shape_inference.h")
)
assert len(upstream_shape_inference_decls) > 0
shape_inference_decls = extract_signatures(
backend_path.joinpath("LazyShapeInference.h")
)
assert len(shape_inference_decls) > 0
shape_inference_defs = extract_signatures(
backend_path.joinpath("LazyShapeInference.cpp")
)
assert len(shape_inference_defs) > 0
assert len(shape_inference_decls) > len(shape_inference_defs)

shape_inference_defs.append(
missing_defs = (
shape_inference_decls
- upstream_shape_inference_decls
- shape_inference_defs
)
if missing_defs:
backend_path.joinpath("GenLazyShapeInference.cpp").write_text(
dedent(
f"""
std::vector<Shape> {name}({signature}) {{
UNIMPLEMENTED_ERROR("{name}");
}}
"""
// This file contains autogenerated Lazy Shape Inference placeholders
// for ops that don't have a corresponding structured kernel or shape definition
#include "LazyShapeInference.h"
#include "../utils/exception.h"
namespace torch {{
namespace lazy {{
{}
}} // namespace lazy
}} // namespace torch
"""
).format(
"".join(
dedent(
f"""
std::vector<Shape> {name}({args}) {{
UNIMPLEMENTED_FUNCTION_ERROR();
}}
"""
)
for name, args in map(
global_signatures.get, sorted(missing_defs)
)
)
)
)

backend_path.joinpath("LazyShapeInference.cpp").write_text(
dedent(
"""
// This file contains autogenerated Lazy Shape Inference placeholders
// for ops that don't have a corresponding structured kernel
#include "LazyShapeInference.h"
#include "../utils/exception.h"
namespace torch {{
namespace lazy {{
{}
}} // namespace lazy
}} // namespace torch
"""
).format("".join(shape_inference_defs))
)
unnecessary_defs = shape_inference_defs - shape_inference_decls
if unnecessary_defs:
unnecessary_defs = "\n\t".join(
f"{name}({args})"
for name, args in map(global_signatures.get, unnecessary_defs)
)
warnings.warn(
f"Unnecessary shape inference definitions found for:\n\t{unnecessary_defs}"
)
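
(Illustration, not part of the diff: the extract_signatures helper above keys each signature by a whitespace-stripped string, so a declaration and a definition that differ only in spacing compare equal. The sample signature below is invented.)

import re

sig_re = re.compile(r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)")

decl = "std::vector<Shape> compute_shape_abs(const at::Tensor & self);"
defn = "std::vector<Shape> compute_shape_abs(const at::Tensor& self) {"

def signature_key(text):
    # Collapse all whitespace, exactly as generate_backend does above.
    name, args = sig_re.findall(text)[0]
    return re.sub(r"\s+", "", f"{name}({args})")

assert signature_key(decl) == signature_key(defn)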


def main(args):
script_path = Path(__file__).resolve()
aten_ops_file = TORCH_MLIR_DIR.joinpath(
"include", "torch-mlir", "Dialect", "Torch", "IR", "GeneratedAtenOps.td"
config_path = (
Path(__file__).resolve().parent.joinpath("autogen_ltc_backend.yaml")
)
assert aten_ops_file.exists()
torch_ops_file = TORCH_MLIR_DIR.joinpath(
"include",
"torch-mlir",
"Dialect",
"Torch",
"IR",
"GeneratedTorchOps.td",
)
assert torch_ops_file.exists()
native_functions = TORCH_MLIR_DIR.joinpath(
"generated_native_functions.yaml"
)
backend_path = TORCH_MLIR_DIR.joinpath(
"python", "torch_mlir", "csrc", "backend"
)
assert backend_path.is_dir()

prev_hash = None
hash_file = TORCH_MLIR_DIR.joinpath("generated_backend.hash")
@@ -266,23 +282,32 @@ def main(args):

m = hashlib.sha256()
m.update(script_path.read_bytes())
m.update(aten_ops_file.read_bytes())
m.update(config_path.read_bytes())
m.update(torch_ops_file.read_bytes())
if native_functions.exists():
m.update(native_functions.read_bytes())

shape_inference_headers = backend_path.joinpath("LazyShapeInference.h")
if shape_inference_headers.exists():
m.update(shape_inference_headers.read_bytes())

shape_inference_defs = backend_path.joinpath("LazyShapeInference.cpp")
if shape_inference_defs.exists():
m.update(shape_inference_defs.read_bytes())

new_hash = m.hexdigest().strip()

if args.force or new_hash != prev_hash:
hash_file.write_text(new_hash)
parsed_yaml, grouped_native_functions = generate_native_functions(
aten_ops_file, native_functions
config_path, torch_ops_file, native_functions
)

backend_path = TORCH_MLIR_DIR.joinpath(
"python", "torch_mlir", "csrc", "backend"
)
generate_backend(
native_functions, backend_path, parsed_yaml, grouped_native_functions
native_functions,
backend_path,
parsed_yaml,
grouped_native_functions,
)
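
(The regeneration gate in main() above is a generic content-hash pattern; here is a minimal standalone sketch, with placeholder file names rather than the script's real paths.)

import hashlib
from pathlib import Path

def needs_regen(inputs, hash_file):
    # Hash every input that exists; any byte change triggers regeneration.
    m = hashlib.sha256()
    for path in inputs:
        if path.exists():
            m.update(path.read_bytes())
    new_hash = m.hexdigest()
    prev_hash = hash_file.read_text().strip() if hash_file.exists() else None
    if new_hash == prev_hash:
        return False
    hash_file.write_text(new_hash)
    return True

if needs_regen([Path("codegen_inputs.yaml")], Path("generated.hash")):
    print("inputs changed; rerunning codegen")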


52 changes: 52 additions & 0 deletions build_tools/autogen_ltc_backend.yaml
@@ -0,0 +1,52 @@
blacklist:
# List of unsupported ops in LTC autogen because of some error
- arange # Error: Code below assumes there is at least one tensor arg
- contiguous # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- full # Error: Code below assumes there is at least one tensor arg
- index.Tensor # Error: TODO not sure if there are other valid types to handle here
- index_put # Error: TODO not sure if there are other valid types to handle here
- index_put_ # Error: TODO not sure if there are other valid types to handle here
- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here
- ones # Error: Code below assumes there is at least one tensor arg
- ones_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- resize_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- stack # Error: TODO not sure if there are other valid types to handle here
- to.dtype # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- to.other # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- uniform_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- zeros # Error: Code below assumes there is at least one tensor arg
- zeros_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)

# Additional ops which autogen is supported for but don't compile yet
- item
- size
- where
- copy_
- _to_copy
- log_softmax # Not inherently differentiable. Needs to be decomposed.
- linear # Not inherently differentiable. Needs to be decomposed.

# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported:
# - bernoulli
# - bernoulli_
- cat
- clone
- empty
- expand
- fill_
# - native_batch_norm_backward
- native_batch_norm
- permute
- repeat
- squeeze
- t
- unsqueeze
- view

additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
- _copy_from
- _copy_from_and_resize
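
(Usage sketch, not part of the commit: this is how autogen_ltc_backend.py consumes the file, per the generate_native_functions diff above. yaml.CLoader assumes PyYAML was built against libyaml; yaml.SafeLoader is a drop-in alternative.)

import yaml

with open("build_tools/autogen_ltc_backend.yaml") as f:
    config = yaml.load(f, yaml.CLoader)

blacklist = set(config.get("blacklist", []))       # ops excluded from autogen
supported = set(config.get("supported", []))       # ops spared full codegen
additional_ops = config.get("additional_ops", [])  # extra ops, e.g. _copy_from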
4 changes: 4 additions & 0 deletions python/torch_mlir/csrc/.clang-format
@@ -0,0 +1,4 @@
BasedOnStyle: LLVM
AlignAfterOpenBracket: AlwaysBreak # BlockIndent
PointerAlignment: Left
ReflowComments: false
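
(Usage note, not from the commit: with --style=file, clang-format searches parent directories for this .clang-format, so formatting any file under python/torch_mlir/csrc/ picks it up. The file path below is illustrative.)

import subprocess

# Format one backend source in place using the nearest .clang-format.
subprocess.run(
    ["clang-format", "-i", "--style=file",
     "python/torch_mlir/csrc/backend/backend_impl.cpp"],
    check=True,
)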
10 changes: 6 additions & 4 deletions python/torch_mlir/csrc/CMakeLists.txt
@@ -25,9 +25,11 @@ add_library(torch_mlir_ltc_backend SHARED
backend/backend_impl.cpp
backend/LazyNativeFunctions.cpp
backend/LazyShapeInference.cpp
backend/GenLazyShapeInference.cpp
backend/mlir_lowering_context.cpp
backend/mlir_node.cpp
backend/RegisterLazy.cpp
tensor_aten_ops.cpp
)

target_link_libraries(torch_mlir_ltc_backend
@@ -40,10 +42,10 @@ target_link_libraries(torch_mlir_ltc_backend
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
set_target_properties(torch_mlir_ltc_backend PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/"
OUTPUT_NAME _MLIR_LTC
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
OUTPUT_NAME lib_mlir_ltc
PREFIX ""
SUFFIX ".so"
CXX_VISIBILITY_PRESET "hidden"
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
LINK_FLAGS "-rdynamic"
)
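
(Usage sketch; the load path and the use of ctypes are assumptions for illustration, not prescribed by the commit. With PREFIX "" and SUFFIX ".so", the target now builds a plain lib_mlir_ltc.so rather than a Python extension module, so it can be loaded as an ordinary shared library once torch's own symbols are available.)

import ctypes

import torch  # load libtorch first so the backend's torch symbols resolve

# Installed under ${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/ per the CMake above.
ctypes.CDLL("torch_mlir/lib_mlir_ltc.so", mode=ctypes.RTLD_GLOBAL)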

2 changes: 1 addition & 1 deletion python/torch_mlir/csrc/README.md
@@ -1,4 +1,4 @@
# Torch-MLIR Lazy Tensor Core Backend
#Torch - MLIR Lazy Tensor Core Backend

Contained within this directory are the components that implement the
Torch-MLIR LTC backend.