Skip to content

Commit

Permalink
Generate MLIR with shape information via LTC frontend (#742)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniojkim committed May 26, 2022
1 parent 602d897 commit ce7a5bc
Show file tree
Hide file tree
Showing 13 changed files with 286 additions and 123 deletions.
2 changes: 1 addition & 1 deletion build_tools/autogen_ltc_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def lowering_function(self, f):
size_t i = 0;
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments);
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
return {schema.aten_name}_out;
Expand Down
2 changes: 1 addition & 1 deletion examples/ltc_backend_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main(device):
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(5, 5)
self.fc1 = torch.nn.Linear(5, 10)

def forward(self, x):
out = self.fc1(x)
Expand Down
67 changes: 54 additions & 13 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ TorchMlirComputation::TorchMlirComputation(
const std::shared_ptr<torch::jit::Graph>& graph)
: func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
graph_(graph), num_results_(graph_->outputs().size()) {

// TODO(henrytu): Save parameter shape information.

for (torch::jit::Value* input : graph_->inputs()) {
parameter_names_.push_back(input->debugName());
}
Expand Down Expand Up @@ -144,22 +141,18 @@ void TorchMlirLoweringContext::AddParameter(
ComputationPtr TorchMlirLoweringContext::Build() {
PRINT_FUNCTION();

// Insert return values into graph.
for (torch::jit::Value* output : root_tuple_) {
graph_->block()->registerOutput(output);
}

// Create jit::Function from jit::Graph.
c10::QualifiedName name("graph");
auto cu = std::make_shared<torch::jit::CompilationUnit>();
// IMPORTANT: We pass in a COPY of the graph into create_function, since it
// may get mutated in the process.
auto jit_fn = cu->create_function(std::move(name), std::move(graph_->copy()));

// Generate MLIR.
MlirOperation func_op =
torch_mlir::importJitFunctionAsFuncOp(mlir_context_, jit_fn);
MlirOperation func_op = torch_mlir::importJitFunctionAsFuncOp(
/*context=*/mlir_context_,
/*function=*/generate_jit_fn().get(),
/*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
/*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});

// TODO(henrytu): Inject tensor shapes into func_op
return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
}

Expand Down Expand Up @@ -224,6 +217,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
TORCH_CHECK(
false, "Unhandled scalar type: ", c10::toString(scalar.type()));
}
} else {
// Save parameter shape information.
param->setType(torch::jit::TensorType::create(
/*scalar_type=*/data->shape().scalar_type(),
/*device=*/c10::nullopt,
/*sizes=*/c10::VaryingShape<int64_t>(data->shape().sizes()),
/*strides=*/c10::VaryingShape<int64_t>(),
/*requires_grad=*/c10::nullopt));
}

it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
Expand All @@ -245,6 +246,46 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
return root_tuple_.size() - 1;
}

// Sync vector of c10::Argument with type specified from parallel list of
// jit::Value. There must be a 1:1 map between elements of args and values.
std::vector<c10::Argument> sync_argument_types(
const std::vector<c10::Argument>& args,
c10::ArrayRef<torch::jit::Value*> values) {
TORCH_CHECK(
args.size() == values.size(),
"Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ",
args.size(), ":", values.size(), " instead!");

std::vector<c10::Argument> updated_args;
for (unsigned i = 0; i < args.size(); i++) {
updated_args.push_back(args[i].cloneWithType(values[i]->type()));
}

return updated_args;
}

std::unique_ptr<torch::jit::Function>
TorchMlirLoweringContext::generate_jit_fn() const {
// IMPORTANT: We pass in a COPY of the graph into create_function, since it
// may get mutated in the process.
auto fn = std::make_unique<torch::jit::GraphFunction>(
c10::QualifiedName("graph"), graph_->copy(), nullptr);

c10::FunctionSchema schema = fn->getSchema();

// When constructing the default schema of a jit::GraphFunction, input and
// output shapes are stripped (via call to unshapedType(...)); however,
// since we want to have shape information in our MLIR, we'll add it back.
std::vector<c10::Argument> arguments =
sync_argument_types(schema.arguments(), graph_->inputs());
std::vector<c10::Argument> returns =
sync_argument_types(schema.returns(), graph_->outputs());

fn->setSchema(schema.cloneWithArguments(arguments).cloneWithReturns(returns));

return fn;
}

void TorchMlirLoweringContext::RegisterMlirDialects() {
// https://reviews.llvm.org/D88162
mlirRegisterAllDialects(mlir_context_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {

size_t AddResult(torch::jit::Value* op);

// Creates a jit::Function from the current jit::Graph. Input and output
// type information is patched to include shape.
std::unique_ptr<torch::jit::Function> generate_jit_fn() const;

void RegisterMlirDialects();

std::shared_ptr<torch::jit::Graph> graph_;
Expand Down
127 changes: 94 additions & 33 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,76 @@
namespace torch {
namespace lazy {

TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
const std::vector<c10::TypePtr> tensor_types,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
auto builtin =
std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
CHECK(sv);

TorchMlirOpVector results;
if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
// Op returns multiple values.
const auto tuple_call_result = sv->asTuple({}, *function);
for (const auto& tuple_component : tuple_call_result) {
auto tuple_component_sv =
dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
results.push_back(tuple_component_sv->getValue());
}
} else {
// Op returns single value.
results.push_back(sv->getValue());
}

// Insert known tensor type information.
unsigned tensor_type_idx = 0;
for (jit::Value* value : results) {
if (value->type()->kind() == c10::TypeKind::TensorType) {
TORCH_CHECK(
tensor_type_idx < tensor_types.size(),
"Tensor corresponding to JIT SSA value %", value->debugName(),
" corresponds to result #", tensor_type_idx, ", but we only have ",
tensor_types.size(), " known types!");

value->setType(tensor_types[tensor_type_idx++]);
}
}

// Ensure that we use up all the known tensor type information available.
TORCH_CHECK(
tensor_type_idx == tensor_types.size(), tensor_type_idx,
" known types were injected into jit::Value, but ", tensor_types.size(),
" were provided from lazy::Node!");

return results;
}

TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
std::vector<c10::TypePtr> tensor_types;

// Generate types with fixed tensor shape information.
for (const Shape& shape : result_shapes) {
tensor_types.push_back(torch::jit::TensorType::create(
/*scalar_type=*/shape.scalar_type(),
/*device=*/c10::nullopt,
/*sizes=*/c10::VaryingShape<int64_t>(shape.sizes()),
/*strides=*/c10::VaryingShape<int64_t>(),
/*requires_grad=*/c10::nullopt));
}

return LowerTorchMlirBuiltin(
function, sym, tensor_types, arguments, kwarguments);
}

class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
public:
TorchMlirNodeLowering(
Expand Down Expand Up @@ -189,12 +259,20 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(
function_, node->op().op, arguments, kwarguments);
function_, node->op().op, node->shapes(), arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const std::vector<torch::jit::NamedValue>& arguments,
c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(function_, sym, arguments, kwarguments);
return LowerTorchMlirBuiltin(
function_, sym, result_shapes, arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const std::vector<c10::TypePtr> types,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
return LowerTorchMlirBuiltin(function_, sym, types, arguments, kwarguments);
}

TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
Expand Down Expand Up @@ -222,7 +300,7 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
dest_arguments.emplace_back(node->stride());
dest_arguments.emplace_back(node->storage_offset());
TorchMlirOpVector as_strided_out =
LowerBuiltin(at::aten::as_strided, dest_arguments);
LowerBuiltin(at::aten::as_strided, node->shapes(), dest_arguments);
CHECK_EQ(as_strided_out.size(), 1);
torch::jit::Value* as_strided = as_strided_out.front();
GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
Expand Down Expand Up @@ -266,7 +344,7 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.emplace_back(node->dtype());
return LowerBuiltin(at::aten::to, arguments);
return LowerBuiltin(at::aten::to, node->shapes(), arguments);
}

TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
Expand Down Expand Up @@ -383,13 +461,16 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
arguments.push_back(node->output_size());
return LowerBuiltin(at::aten::reshape, arguments);
return LowerBuiltin(at::aten::reshape, node->shapes(), arguments);
}

torch::jit::Value* GenerateClone(torch::jit::Value* val) {
std::vector<torch::jit::NamedValue> clone_arguments;
clone_arguments.emplace_back(val);
TorchMlirOpVector cloned = LowerBuiltin(at::aten::clone, clone_arguments);

// Type of cloned value should be identical to the original one.
TorchMlirOpVector cloned =
LowerBuiltin(at::aten::clone, {val->type()}, clone_arguments);
CHECK_EQ(cloned.size(), 1);
return cloned.front();
}
Expand All @@ -398,7 +479,9 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(destination);
arguments.emplace_back(source);
LowerBuiltin(at::aten::copy_, arguments);
LowerBuiltin(
at::aten::copy_, c10::ArrayRef<Shape>({/*shape goes here*/}),
arguments);
}

torch::jit::Value* GenerateSlice(
Expand All @@ -410,7 +493,9 @@ class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
arguments.emplace_back(start);
arguments.emplace_back(end);
arguments.emplace_back(step);
TorchMlirOpVector selected = LowerBuiltin(at::aten::slice, arguments);
TorchMlirOpVector selected = LowerBuiltin(
at::aten::slice, c10::ArrayRef<Shape>({/*shape goes here*/}),
arguments);
CHECK_EQ(selected.size(), 1);
return selected.front();
}
Expand All @@ -424,29 +509,5 @@ TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
"TorchMlirNodeLowering",
static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
}

TorchMlirOpVector LowerTorchMlirBuiltin(
std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments) {
auto builtin =
std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
CHECK(sv);
if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
const auto tuple_call_result = sv->asTuple({}, *function);
TorchMlirOpVector tuple_result;
for (const auto& tuple_component : tuple_call_result) {
auto tuple_component_sv =
dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
tuple_result.push_back(tuple_component_sv->getValue());
}
return tuple_result;
}
return {sv->getValue()};
}

} // namespace lazy
} // namespace torch
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;

TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
TorchMlirFunction function, c10::Symbol sym,
const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue>& arguments,
const std::vector<torch::jit::NamedValue>& kwarguments = {});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ using namespace torch_mlir;

MlirOperation torch_mlir::importJitFunctionAsFuncOp(
MlirContext context, torch::jit::Function *function,
std::function<MlirAttribute(int)> getArgAttribute) {
std::function<MlirAttribute(int)> getArgAttribute,
const ImportOptions &importOptions) {
// Useful for debugging:
// graph->dump();
MlirLocation loc = mlirLocationUnknownGet(context);
MlirType functionType =
getFunctionTypeFromSchema(context, function->getSchema());
getFunctionTypeFromSchema(context, function->getSchema(), importOptions);
// Use the function's qualified name from the compilation unit.
// This is a stable linkage name that matches Python module lookup
// conventions (see compilation unit import in IValueImporter for more details
Expand Down Expand Up @@ -70,7 +71,7 @@ MlirOperation torch_mlir::importJitFunctionAsFuncOp(
};
MlirBlock block = importBlock(
context, torch::jit::toGraphFunction(*function).graph()->block(),
createTerminator, inputTypes);
createTerminator, inputTypes, importOptions);
mlirRegionAppendOwnedBlock(bodyRegion, block);
return func;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <memory>

#include "import_options.h"
#include "node_importer.h"
#include "pybind.h"

Expand Down Expand Up @@ -42,7 +43,8 @@ namespace torch_mlir {
TORCH_API MlirOperation importJitFunctionAsFuncOp(
MlirContext context, torch::jit::Function *function,
std::function<MlirAttribute(int)> getArgAttribute =
[](int) -> MlirAttribute { return {nullptr}; });
[](int) -> MlirAttribute { return {nullptr}; },
const ImportOptions &importOptions = {});

} // namespace torch_mlir

Expand Down
Loading

0 comments on commit ce7a5bc

Please sign in to comment.