Skip to content

Commit

Permalink
Enable support for LTC Input/Output Mapping (#764)
Browse files Browse the repository at this point in the history
* Save InputOutputAliases to TorchMlirComputation

* Implement GetResultShape for TorchMlirLoweringContext

* Use optional return type for GetResultShape

* Remove support for aten::detach

With this op enabled, tensors were being copied, which resulted in incorrect aliasing.

* Add newline before printing I/O alias mapping

* Changed printout to use "Input param" as label instead of "Input"

* Remote shape inference function for aten::detach

* Moved implementation of SetUpAlias to MlirLoweringContext

As part of this change, TorchMlirComputation has been moved to the end of mlir_lowering_context.h so that it can access some new structs in TorchMlirLoweringContext

* Use updated PyTorch API

* Remove GetResultShape

Complements this upstream PyTorch PR: pytorch/pytorch#75828

This PR adds support for mapping input and output tensors which alias each other. (e.g. maps input weight tensor in parameter to the same tensor in output after a training iteration)

MLIR: 
func @graph(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1],si64>, ..., %arg6: !torch.vtensor<[10,5],f32>, %arg7: !torch.vtensor<[10],f32>, ...) {
  ...
  return %arg0, %arg1, %17, %23, ... : !torch.vtensor<[1,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[10,5],f32>, !torch.vtensor<[10],f32>, ...
}

Input/Output Alias Mapping: 
Output: 0 -> Input: 0
Output: 1 -> Input: 1
Output: 2 -> Input: 6
Output: 3 -> Input: 7
The aten::detach op has also been disabled in this PR to fix the issue of tensors not aliasing properly due to copying.
  • Loading branch information
henrytwo committed Jul 29, 2022
1 parent 05f3441 commit 8627bed
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 101 deletions.
1 change: 0 additions & 1 deletion build_tools/autogen_ltc_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def extract_signatures(path):
shape_inference_defs = extract_signatures(
backend_path.joinpath("LazyShapeInference.cpp")
)
assert len(shape_inference_defs) > 0
assert len(shape_inference_decls) > len(shape_inference_defs)

missing_defs = (
Expand Down
1 change: 1 addition & 0 deletions build_tools/autogen_ltc_backend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ blacklist:
- zeros_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)

# Additional ops which autogen is supported for but don't compile yet
- detach
- item
- size
- where
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,5 @@
namespace torch {
namespace lazy {

std::vector<Shape> compute_shape_detach(const at::Tensor& self) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}

} // namespace lazy
} // namespace torch
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ TORCH_API std::vector<Shape> compute_shape_broadcast_to(const at::Tensor & self,
TORCH_API std::vector<Shape> compute_shape_bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right);
TORCH_API std::vector<Shape> compute_shape_constant_pad_nd(const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value);
TORCH_API std::vector<Shape> compute_shape_conv2d(const at::Tensor & input, const at::Tensor & weight, const c10::optional<at::Tensor> & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups);
TORCH_API std::vector<Shape> compute_shape_detach(const at::Tensor & self);
TORCH_API std::vector<Shape> compute_shape_div(const at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_div_(at::Tensor & self, const at::Scalar & other);
TORCH_API std::vector<Shape> compute_shape_dropout(const at::Tensor & input, double p, bool train);
Expand Down
162 changes: 100 additions & 62 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,63 +27,6 @@
namespace torch {
namespace lazy {

///////////////////////////////////////////////////////////////////////////////
// TorchMlir Computation
///////////////////////////////////////////////////////////////////////////////

TorchMlirComputation::TorchMlirComputation(
MlirOperation func_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph>& graph)
: func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
graph_(graph), num_results_(graph_->outputs().size()) {
for (torch::jit::Value* input : graph_->inputs()) {
parameter_names_.push_back(input->debugName());
}
}

int TorchMlirComputation::parameters_size() const {
return parameter_names_.size();
}

const std::vector<torch::lazy::Shape>&
TorchMlirComputation::parameter_shapes() const {
throw std::runtime_error(
"todo(whc) implement ts computation shapes or change interface");
return parameter_shapes_;
}

const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
return parameter_names_;
}

const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
throw std::runtime_error(
"todo(whc) implement ts computation shapes or change interface");
return result_shape_;
}

unsigned TorchMlirComputation::num_results() const { return num_results_; }

MlirOperation TorchMlirComputation::func_op() const { return func_op_; }

std::string TorchMlirComputation::to_string() const {
// Since we use the C-MLIR API, we need to use a callback to print.
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
// user_data is a void ptr to some data structure of our choice -- in this
// case, the string stream where we'll be accumulating the strings.
std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
*ss_ptr << std::string(part.data, part.length);
};

std::stringstream ss;
ss << "JIT Graph: \n"
<< graph_->toString() << "\n\n"
<< "MLIR: \n";
mlirOperationPrint(func_op_, print_callback, &ss);

return ss.str();
}

///////////////////////////////////////////////////////////////////////////////
// TorchMlir Lowering Context
///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -115,10 +58,34 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
RegisterMlirDialects();
}

// Get the shape of the result tuple component, given by index.
torch::lazy::Shape
TorchMlirLoweringContext::GetResultShape(size_t index) const {
UNIMPLEMENTED_FUNCTION_ERROR();
void TorchMlirLoweringContext::SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
const std::vector<int64_t>& param_index, bool must_alias) {
input_output_aliases_.push_back(
{output_index, param_number, param_index, must_alias});
}

bool TorchMlirLoweringContext::CheckResultShape(
const BackendDataPtr& parameter_data, size_t result_idx) {
TORCH_CHECK(
result_idx < root_tuple_.size(), "Tried getting result shape at index ",
result_idx, " which is out of bounds!");

torch::jit::Value* output = root_tuple_[result_idx];

if (c10::TensorTypePtr tensor_type =
output->type()->cast<c10::TensorType>()) {
auto scalar_type = tensor_type->scalarType();
auto sizes = tensor_type->sizes().concrete_sizes();

// Not guaranteed to have concrete size, so we need to check it exists.
if (scalar_type && sizes) {
return Shape(parameter_data->shape()) ==
Shape(scalar_type.value(), c10::ArrayRef<int64_t>(sizes.value()));
}
}

return false;
}

size_t TorchMlirLoweringContext::AddResult(const Output& output) {
Expand Down Expand Up @@ -153,7 +120,8 @@ ComputationPtr TorchMlirLoweringContext::Build() {
/*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
/*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});

return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
return std::make_shared<TorchMlirComputation>(
func_op, mlir_context_, graph_, input_output_aliases_);
}

torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
Expand Down Expand Up @@ -292,5 +260,75 @@ void TorchMlirLoweringContext::RegisterMlirDialects() {
torchMlirRegisterAllDialects(mlir_context_);
}

///////////////////////////////////////////////////////////////////////////////
// TorchMlir Computation
///////////////////////////////////////////////////////////////////////////////

TorchMlirComputation::TorchMlirComputation(
MlirOperation func_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph>& graph,
InputOutputAliases input_output_aliases)
: func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
graph_(graph), input_output_aliases_(input_output_aliases),
num_results_(graph_->outputs().size()) {
for (torch::jit::Value* input : graph_->inputs()) {
parameter_names_.push_back(input->debugName());
}
}

int TorchMlirComputation::parameters_size() const {
return parameter_names_.size();
}

const std::vector<torch::lazy::Shape>&
TorchMlirComputation::parameter_shapes() const {
throw std::runtime_error(
"todo(whc) implement ts computation shapes or change interface");
return parameter_shapes_;
}

const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
return parameter_names_;
}

const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
throw std::runtime_error(
"todo(whc) implement ts computation shapes or change interface");
return result_shape_;
}

unsigned TorchMlirComputation::num_results() const { return num_results_; }

MlirOperation TorchMlirComputation::func_op() const { return func_op_; }

std::string TorchMlirComputation::to_string() const {
// Since we use the C-MLIR API, we need to use a callback to print.
MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
// user_data is a void ptr to some data structure of our choice -- in this
// case, the string stream where we'll be accumulating the strings.
std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
*ss_ptr << std::string(part.data, part.length);
};

std::stringstream ss;

// JIT Graph
ss << "JIT Graph: \n" << graph_->toString() << "\n\n";

// MLIR
ss << "MLIR: \n";
mlirOperationPrint(func_op_, print_callback, &ss);
ss << "\n";

// Input/Output Mapping
ss << "Input/Output Alias Mapping: \n";
for (InputOutputAlias input_output_alias : input_output_aliases_) {
ss << "Output: " << input_output_alias.output_index
<< " -> Input param: " << input_output_alias.param_number << std::endl;
}

return ss.str();
}

} // namespace lazy
} // namespace torch
93 changes: 60 additions & 33 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,48 +39,37 @@ class TORCH_API TorchMlirNodeLoweringInterface {
Create(LoweringContext* loctx);
};

class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
public:
TorchMlirComputation(
MlirOperation func_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph>& graph);

int parameters_size() const override;

const std::vector<torch::lazy::Shape>& parameter_shapes() const override;

const std::vector<std::string>& parameter_names() const override;

const torch::lazy::Shape& result_shape() const override;

unsigned num_results() const;

MlirOperation func_op() const;

std::string to_string() const;

private:
std::vector<std::string> parameter_names_;
std::vector<Shape> parameter_shapes_;
Shape result_shape_;

MlirOperation func_op_;
MlirContext mlir_context_;
std::shared_ptr<torch::jit::Graph> graph_;
unsigned num_results_;
};

class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
public:
// Describes an input/output alias as inserted by the SetUpAlias() API.
struct InputOutputAlias {
// Specifies the index of the aliased buffer in the result tuple.
std::vector<int64_t> output_index;
// Specifies the parameter containing the buffer to be aliased.
int64_t param_number;
// Specifies the index of the aliased buffer in the parameter
std::vector<int64_t> param_index;
// Specifies if the alias is a must alias or may alias.
bool must_alias;
};
using InputOutputAliases = std::vector<InputOutputAlias>;

TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device);
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);

// Get the shape of the result tuple component, given by index.
torch::lazy::Shape GetResultShape(size_t index) const override;
// Adds a new input/output alias.
void SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
const std::vector<int64_t>& param_index,
bool must_alias = false) override;

// Check if parameter shape matches result at index.
bool CheckResultShape(
const BackendDataPtr& parameter_data, size_t result_idx) override;

// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
Expand Down Expand Up @@ -128,6 +117,8 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {

void RegisterMlirDialects();

// Holds the input/output alias information populated by the SetUpAlias() API.
InputOutputAliases input_output_aliases_;
std::shared_ptr<torch::jit::Graph> graph_;
MlirContext mlir_context_;
std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
Expand All @@ -136,5 +127,41 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
};

class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
public:
using InputOutputAliases = TorchMlirLoweringContext::InputOutputAliases;
using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias;

TorchMlirComputation(
MlirOperation func_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph>& graph,
InputOutputAliases input_output_aliases);

int parameters_size() const override;

const std::vector<torch::lazy::Shape>& parameter_shapes() const override;

const std::vector<std::string>& parameter_names() const override;

const torch::lazy::Shape& result_shape() const override;

unsigned num_results() const;

MlirOperation func_op() const;

std::string to_string() const;

private:
std::vector<std::string> parameter_names_;
std::vector<Shape> parameter_shapes_;
Shape result_shape_;

MlirOperation func_op_;
MlirContext mlir_context_;
std::shared_ptr<torch::jit::Graph> graph_;
InputOutputAliases input_output_aliases_;
unsigned num_results_;
};

} // namespace lazy
} // namespace torch

0 comments on commit 8627bed

Please sign in to comment.