Skip to content

Commit

Permalink
Refactor Node Lowering (#914)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniojkim authored and henrytwo committed Jul 12, 2022
1 parent 41c2449 commit 770bcaa
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 353 deletions.
27 changes: 21 additions & 6 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device)
: LoweringContext(name, std::forward<BackendDevice>(device)),
graph_(std::make_shared<torch::jit::Graph>()),
function_(
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
mlir_context_(mlirContextCreate()) {
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
RegisterMlirDialects();
}

Expand All @@ -49,16 +50,31 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
std::forward<c10::ArrayRef<torch::lazy::Node*>>(post_order),
std::forward<Util::EmissionMap>(emit_status)),
graph_(std::make_shared<torch::jit::Graph>()),
function_(
std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)),
mlir_context_(mlirContextCreate()) {
lowering_ = TorchMlirNodeLoweringInterface::Create(this);
for (auto node : post_order) {
bool ok = lowering_->Lower(node);
CHECK(ok) << "Failed to lower: " << *node;
Lower(node);
}

RegisterMlirDialects();
}

void TorchMlirLoweringContext::Lower(const Node* node) {
if (auto* torch_mlir_node =
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this);
CHECK(!ops.empty()) << "Failed to lower: " << *node;
CHECK_EQ(node->num_outputs(), ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
}
} else {
throw std::runtime_error(
"Expected torch::lazy::TorchMlirNode but could not dynamic cast");
}
}

void TorchMlirLoweringContext::SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
const std::vector<int64_t>& param_index, bool must_alias) {
Expand Down Expand Up @@ -136,8 +152,7 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
if (it == emitted_outputs_.end()) {
auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
for (auto node : post_order) {
bool ok = lowering_->Lower(node);
TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
Lower(node);
}
// At this point the output better be present, otherwise there is an issue
// with the lowering code.
Expand Down
20 changes: 3 additions & 17 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,6 @@
namespace torch {
namespace lazy {

class TORCH_API TorchMlirNodeLoweringInterface {
/**
* This interface is only needed for legacy ops, and can be removed once all
* ops implement LtcMlirNode->lower().
* */
public:
TorchMlirNodeLoweringInterface() = default;

virtual ~TorchMlirNodeLoweringInterface() = default;

virtual bool Lower(const Node* node) = 0;

static std::unique_ptr<TorchMlirNodeLoweringInterface>
Create(LoweringContext* loctx);
};

class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
public:
// Describes an input/output alias as inserted by the SetUpAlias() API.
Expand All @@ -61,6 +45,8 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
c10::ArrayRef<torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);

void Lower(const Node* node);

// Adds a new input/output alias.
void SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
Expand Down Expand Up @@ -120,11 +106,11 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
// Holds the input/output alias information populated by the SetUpAlias() API.
InputOutputAliases input_output_aliases_;
std::shared_ptr<torch::jit::Graph> graph_;
std::shared_ptr<torch::jit::GraphFunction> function_;
MlirContext mlir_context_;
std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
std::vector<torch::jit::Value*> root_tuple_;
OutputMap<torch::jit::Value*> emitted_outputs_;
std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
};

class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
Expand Down
6 changes: 0 additions & 6 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@ hash_t TorchMlirNode::hash() const { return dag_hash_; }

hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }

TorchMlirOpVector TorchMlirNode::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
return {};
}


OpKind TorchMlirTensorList::ClassOpKind() {
// Note: this OpKind is separate from ltc_ops.h since it would be a circular
// import otherwise
Expand Down
Loading

0 comments on commit 770bcaa

Please sign in to comment.