Skip to content

Commit

Permalink
[wip] Basic fuser pass to select texpr subgraphs
Browse files Browse the repository at this point in the history
  • Loading branch information
bertmaher committed Jan 17, 2020
1 parent d1c7556 commit a9d9919
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 0 deletions.
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/subgraph_rewrite.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/tensorexpr_fuser.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
Expand Down
170 changes: 170 additions & 0 deletions torch/csrc/jit/passes/tensorexpr_fuser.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#include <torch/csrc/autograd/record_function.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/operator_options.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/pass_manager.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

using namespace torch::jit;

namespace {

// Interned symbol naming the fusion-group node this pass creates.
const Symbol& getTensorExprSymbol() {
  static Symbol group_symbol = Symbol::fromQualString("tensorexpr::Group");
  return group_symbol;
}

// Collect the subset of `inputs` that are produced inside `block`, ordered
// so that later-defined values come first (reverse topological order).
value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* block) {
  value_list local_inputs;
  for (Value* input : inputs) {
    if (input->node()->owningBlock() == block) {
      local_inputs.push_back(input);
    }
  }
  // Later nodes sort before earlier ones.
  std::sort(
      local_inputs.begin(),
      local_inputs.end(),
      [](Value* lhs, Value* rhs) { return lhs->node()->isAfter(rhs->node()); });
  return local_inputs;
}

// Whether `node` is an op the tensor-expression backend can compile.
// TODO: extend beyond aten::add.
bool isSupported(Node* node) {
  // Cache the symbol: Symbol::fromQualString parses and interns the string,
  // so re-doing it on every node visited by the pass is wasted work.
  static const Symbol aten_add = Symbol::fromQualString("aten::add");
  return node->kind() == aten_add;
}

// Whether `node` may be placed inside a tensorexpr fusion group.
// `aliasDb` is currently unused but kept for interface stability.
bool canHandle(Node* node, AliasDb& aliasDb) {
  const auto kind = node->kind();
  if (kind == prim::Constant) {
    return true;
  }
  if (kind == prim::Loop) {
    // Loops are not fusible yet.
    return false; // TODO
  }
  return isSupported(node);
}

// Fusion-precondition check used inside tryMerge: if `cond` fails, log the
// stringized condition at GRAPH_DEBUG level and bail out with nullopt.
// Only usable inside functions returning c10::optional.
#define REQ(cond) \
if (!(cond)) { \
GRAPH_DEBUG("Failed cond " #cond "\n"); \
return c10::nullopt; \
}

// Try to pull `producer` into a fusion group formed around `consumer`.
// Returns the (possibly newly created) tensorexpr::Group node on success,
// or nullopt if any safety check fails. On success the graph has been
// mutated: `consumer` may have been wrapped in a singleton subgraph and
// `producer` merged (or, for constants, cloned) into it.
c10::optional<Node*> tryMerge(
    Node* consumer,
    Node* producer,
    AliasDb& aliasDb) {
  GRAPH_DEBUG(
      "Trying producer ",
      producer->kind().toQualString(),
      " and consumer ",
      consumer->kind().toQualString(),
      ":\n");

  // Symbolic checks: both nodes must be fusible; the consumer may also
  // already be an existing fusion group.
  REQ(canHandle(producer, aliasDb));
  REQ((canHandle(consumer, aliasDb) ||
       consumer->kind() == getTensorExprSymbol()));

  // Alias checks
  // Requirement:
  // - moveAfterTopologicallyValid(consumer, producer)
  // - One of:
  //   1) Both are in-place ops
  //   2) Consumer is in-place, producer !hasInputWriters
  //   3) Producer is in-place, consumer !hasOutputWriters
  // NOTE: moveAfterTopologicallyValid also MOVES the consumer when it
  // succeeds, so order of these checks matters.
  REQ(aliasDb.moveAfterTopologicallyValid(consumer, producer));

  // 1)
  if (!(aliasDb.isMutable(consumer) && aliasDb.isMutable(producer))) {
    // 2)
    if (aliasDb.isMutable(consumer)) {
      REQ(!aliasDb.hasInputWriters(producer));
      // 3)
    } else if (aliasDb.isMutable(producer)) {
      REQ(!aliasDb.hasOutputWriters(consumer));
    }
  }

  // Wrap a plain consumer node into a fresh singleton fusion subgraph so
  // the producer has something to merge into.
  if (!consumer->hasAttribute(attr::Subgraph) &&
      consumer->kind() != getTensorExprSymbol()) {
    consumer =
        SubgraphUtils::createSingletonSubgraph(consumer, getTensorExprSymbol());
  }
  if (producer->kind() == prim::Constant) {
    // Constants are cloned into the subgraph rather than merged; a constant
    // has no inputs, so the value-mapping callback must never be invoked.
    auto& subgraph = consumer->g(attr::Subgraph);
    Node* in_const = subgraph->createClone(producer, [](Value*) -> Value* {
      throw std::runtime_error("unexpected input");
    });
    subgraph->insertNode(in_const);
  } else {
    SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
  }
  return consumer;
}
#undef REQ

// Attempt to fuse `consumer` with each of its in-block producers, trying
// the most recently defined producers first. Returns the iterator the
// caller should resume reverse-scanning from, plus whether a merge occurred.
std::pair<graph_node_list::iterator, bool> scanNode(
    Node* consumer,
    AliasDb& aliasDb,
    Block* block) {
  for (Value* producer_val :
       sortReverseTopological(consumer->inputs(), block)) {
    auto fusion_group = tryMerge(consumer, producer_val->node(), aliasDb);
    if (fusion_group) {
      // A successful merge may change the group's inputs, so rescan the
      // group itself for further fusion opportunities.
      return {(*fusion_group)->reverseIterator(), true};
    }
  }
  return {++consumer->reverseIterator(), false};
}

void fuseTensorExprs(std::shared_ptr<Graph>& graph) {
std::cout << "Entering TExprFuser\n";
std::cout << *graph;

AliasDb aliasDb(graph);
auto block = graph->block();

bool any_changed = true;
while (any_changed) {
any_changed = false;
for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
bool changed;
std::tie(it, changed) = scanNode(*it, aliasDb, block);
any_changed |= changed;
}
}

EliminateCommonSubexpression(graph);
EliminateDeadCode(graph);

std::cout << "Finishing TExprFuser\n";
std::cout << *graph;
}

// Build the Operation executed when a tensorexpr::Group node is run.
// NOTE(review): this is a stub — it records a profiling event and returns
// without popping inputs or pushing outputs onto the stack; actual kernel
// compilation/execution is still TODO (commit is marked [wip]).
Operation createTensorExprOp(const Node* node) {
  return [](Stack& stack) {
    RECORD_FUNCTION("TensorExprGroup", std::vector<c10::IValue>());
    // Do something?
    return 0;
  };
}

// Helper: build OperatorOptions carrying the given alias-analysis kind.
c10::OperatorOptions getAliasAnalysisOption(AliasAnalysisKind k) {
  c10::OperatorOptions options;
  options.setAliasAnalysis(k);
  return options;
}

// Register an executable operator for tensorexpr::Group nodes so the
// interpreter can run them.
// NOTE(review): PURE_FUNCTION assumes fused groups never mutate their
// inputs — confirm this stays true once in-place ops are fused (tryMerge
// already reasons about mutable producers/consumers).
RegisterOperators TensorExprOps({
    torch::jit::Operator(
        getTensorExprSymbol(),
        createTensorExprOp,
        getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)),
});

// Install the fusion pass into the JIT's registered-pass pipeline.
RegisterPass pass(fuseTensorExprs);

} // namespace

0 comments on commit a9d9919

Please sign in to comment.