Enable support for LTC Input/Output Mapping #764

Merged
1 change: 1 addition & 0 deletions build_tools/autogen_ltc_backend.yaml
@@ -19,6 +19,7 @@ blacklist:
 - zeros_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)

 # Additional ops which autogen is supported for but don't compile yet
+- detach
 - item
 - size
 - where
@@ -13,9 +13,5 @@
 namespace torch {
 namespace lazy {

-std::vector<Shape> compute_shape_detach(const at::Tensor& self) {
-  return {Shape(self.scalar_type(), self.sizes().vec())};
-}
-
 } // namespace lazy
 } // namespace torch
@@ -38,7 +38,6 @@ TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self,
 TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
 TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
 TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
-TORCH_API std::vector<Shape> compute_shape_detach(const at::Tensor & self);
 TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
 TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
 TORCH_API std::vector<Shape> compute_shape_dropout(const at::Tensor & input, double p, bool train);
48 changes: 40 additions & 8 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
@@ -33,9 +33,11 @@ namespace lazy {

 TorchMlirComputation::TorchMlirComputation(
     MlirOperation func_op, MlirContext mlir_context,
-    const std::shared_ptr<torch::jit::Graph>& graph)
+    const std::shared_ptr<torch::jit::Graph>& graph,
+    InputOutputAliases input_output_aliases)
     : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
-      graph_(graph), num_results_(graph_->outputs().size()) {
+      graph_(graph), input_output_aliases_(input_output_aliases),
+      num_results_(graph_->outputs().size()) {
   for (torch::jit::Value* input : graph_->inputs()) {
     parameter_names_.push_back(input->debugName());
   }
@@ -76,10 +78,21 @@ std::string TorchMlirComputation::to_string() const {
   };

   std::stringstream ss;
-  ss << "JIT Graph: \n"
-     << graph_->toString() << "\n\n"
-     << "MLIR: \n";
+
+  // JIT Graph
+  ss << "JIT Graph: \n" << graph_->toString() << "\n\n";
+
+  // MLIR
+  ss << "MLIR: \n";
   mlirOperationPrint(func_op_, print_callback, &ss);
   ss << "\n";
+
+  // Input/Output Mapping
+  ss << "Input/Output Alias Mapping: \n";
+  for (InputOutputAlias input_output_alias : input_output_aliases_) {
+    ss << "Output: " << input_output_alias.output_index
+       << " -> Input param: " << input_output_alias.param_number << std::endl;
+  }
+
   return ss.str();
 }
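
For context, the string assembled above now has three labeled sections. As an illustrative sketch only (the JIT graph and MLIR bodies are elided here, and the single alias entry is hypothetical), a computation whose output 0 reuses parameter 0 would print roughly:

    JIT Graph:
    <graph_->toString() output>

    MLIR:
    <printed func_op_>

    Input/Output Alias Mapping:
    Output: 0 -> Input param: 0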
@@ -116,9 +129,27 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
 }

 // Get the shape of the result tuple component, given by index.
-torch::lazy::Shape
+c10::optional<torch::lazy::Shape>
 TorchMlirLoweringContext::GetResultShape(size_t index) const {
-  UNIMPLEMENTED_FUNCTION_ERROR();
+  TORCH_CHECK(
+      index < root_tuple_.size(), "Tried getting result shape at index ", index,
+      " which is out of bounds!");
+
+  torch::jit::Value* output = root_tuple_[index];
+
+  if (c10::TensorTypePtr tensor_type =
+          output->type()->cast<c10::TensorType>()) {
+    auto scalar_type = tensor_type->scalarType();
+    auto sizes = tensor_type->sizes().concrete_sizes();
+
+    // Not guaranteed to have concrete size, so we need to check it exists.
+    if (scalar_type && sizes) {
+      return Shape(scalar_type.value(), c10::ArrayRef<int64_t>(sizes.value()));
+    }
+  }
+
+  // No shape information.
+  return c10::nullopt;
 }

 size_t TorchMlirLoweringContext::AddResult(const Output& output) {
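
GetResultShape() previously aborted via UNIMPLEMENTED_FUNCTION_ERROR(); it now returns a populated optional only when the JIT value is a tensor type with both a known scalar type and concrete sizes. A minimal caller-side sketch, where ctx and fallback_shape are hypothetical stand-ins rather than names from this patch:

    // Prefer the statically known shape; fall back when the tuple component
    // carries no concrete shape information.
    c10::optional<torch::lazy::Shape> maybe_shape = ctx->GetResultShape(/*index=*/0);
    torch::lazy::Shape result_shape =
        maybe_shape.has_value() ? *maybe_shape : fallback_shape;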
@@ -153,7 +184,8 @@ ComputationPtr TorchMlirLoweringContext::Build() {
       /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
       /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});

-  return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
+  return std::make_shared<TorchMlirComputation>(
+      func_op, mlir_context_, graph_, input_output_aliases_);
 }

 torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
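Build() now forwards whatever aliases have accumulated in input_output_aliases_, a member this backend inherits from the upstream torch::lazy::LoweringContext (the header diff below only adds the member to TorchMlirComputation, not to TorchMlirLoweringContext). For illustration only, assuming InputOutputAliases is a std::vector of {output_index, param_number} entries, a computation could also be constructed with an explicit alias list; func_op, mlir_context, and graph stand in for values produced during lowering:

    // Hypothetical direct construction; in practice Build() supplies the
    // aliases recorded by the lowering context during tracing.
    TorchMlirComputation::InputOutputAliases aliases;
    aliases.push_back({/*output_index=*/0, /*param_number=*/0}); // output 0 reuses input 0
    auto computation = std::make_shared<TorchMlirComputation>(
        func_op, mlir_context, graph, aliases);
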
@@ -41,9 +41,13 @@ class TORCH_API TorchMlirNodeLoweringInterface {

 class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
  public:
+  using InputOutputAliases = LoweringContext::InputOutputAliases;
+  using InputOutputAlias = LoweringContext::InputOutputAlias;
+
   TorchMlirComputation(
       MlirOperation func_op, MlirContext mlir_context,
-      const std::shared_ptr<torch::jit::Graph>& graph);
+      const std::shared_ptr<torch::jit::Graph>& graph,
+      InputOutputAliases input_output_aliases);

   int parameters_size() const override;
@@ -67,6 +71,7 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
   MlirOperation func_op_;
   MlirContext mlir_context_;
   std::shared_ptr<torch::jit::Graph> graph_;
+  InputOutputAliases input_output_aliases_;
   unsigned num_results_;
 };
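
These aliases reuse the upstream torch::lazy types. Judging purely from the fields to_string() reads (output_index and param_number), each entry is shaped roughly like the sketch below; this is an inference from usage in this patch, not the authoritative upstream definition, which may carry additional fields:

    // Approximate shape of one alias entry, inferred from its usage in
    // mlir_lowering_context.cpp above.
    struct InputOutputAlias {
      int64_t output_index;  // position within the result tuple
      int64_t param_number;  // parameter whose buffer the output may reuse
    };
    using InputOutputAliases = std::vector<InputOutputAlias>;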

@@ -80,7 +85,7 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
       torch::lazy::Util::EmissionMap emit_status);

   // Get the shape of the result tuple component, given by index.
-  torch::lazy::Shape GetResultShape(size_t index) const override;
+  c10::optional<torch::lazy::Shape> GetResultShape(size_t index) const override;

   // Adds the given output as a component of the result tuple and returns its
   // assigned position within the tuple.