Skip to content

Commit

Permalink
Fix ReplaceExprsInScope (pytorch#101)
Browse files Browse the repository at this point in the history
Fix for issue pytorch#88
  • Loading branch information
tlemo authored Jun 18, 2020
1 parent 1355059 commit d060620
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 42 deletions.
12 changes: 8 additions & 4 deletions torch/csrc/jit/codegen/cuda/ir_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,6 @@ struct TORCH_CUDA_API Val : public Statement {
const DataType dtype_;
};

// TODO: We should use this for the following:
// Fusion
// IfThenElse
// ForLoop
struct TORCH_CUDA_API Scope {
public:
const std::vector<Expr*>& exprs() const noexcept {
Expand All @@ -277,6 +273,14 @@ struct TORCH_CUDA_API Scope {
return exprs_.size();
}

auto& operator[](size_t i) {
return exprs_[i];
}

auto& operator[](size_t i) const {
return exprs_[i];
}

// Insert expr before ref
void insert_before(Expr* ref, Expr* expr);

Expand Down
64 changes: 26 additions & 38 deletions torch/csrc/jit/codegen/cuda/lower_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,56 +200,44 @@ struct CloneLoopNest : public OptOutMutator {
};

struct ReplaceExprsInScope : public OptOutDispatch {
private:
std::unordered_map<Expr*, Expr*> replacement_map_;

void handle(Expr* expr) final {
OptOutDispatch::handle(expr);
public:
static void replace(
Expr* scope,
std::unordered_map<Expr*, Expr*> replacement_map) {
ReplaceExprsInScope reis(std::move(replacement_map));
reis.handle(scope);
}

void handle(ForLoop* fl) final {
for (Expr* expr : fl->body().exprs()) {
auto it = replacement_map_.find(expr);
private:
explicit ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> replacement_map)
: replacement_map_(std::move(replacement_map)) {}

void handleScope(Scope& scope) {
for (size_t i = 0; i < scope.size(); ++i) {
const auto it = replacement_map_.find(scope[i]);
if (it == replacement_map_.end()) {
handle(expr);
handle(scope[i]);
continue;
}
fl->body().insert_before(expr, replacement_map_[expr]);
fl->body().erase(expr);
scope[i] = it->second;
}
}

void handle(IfThenElse* ite) final {
for (Expr* expr : ite->body().exprs()) {
auto it = replacement_map_.find(expr);
if (it == replacement_map_.end()) {
handle(expr);
continue;
}
ite->body().insert_before(expr, replacement_map_[expr]);
ite->body().erase(expr);
}
for (Expr* expr : ite->elseBody().exprs()) {
auto it = replacement_map_.find(expr);
if (it == replacement_map_.end()) {
handle(expr);
continue;
}
ite->elseBody().insert_before(expr, replacement_map_[expr]);
ite->elseBody().erase(expr);
}
void handle(Expr* expr) final {
OptOutDispatch::handle(expr);
}

ReplaceExprsInScope(std::unordered_map<Expr*, Expr*> _replacement_map)
: replacement_map_(std::move(_replacement_map)) {}
void handle(ForLoop* fl) final {
handleScope(fl->body());
}

public:
static void replace(
Expr* scope,
std::unordered_map<Expr*, Expr*> replacement_map) {
ReplaceExprsInScope reis(std::move(replacement_map));
reis.handle(scope);
void handle(IfThenElse* ite) final {
handleScope(ite->body());
handleScope(ite->elseBody());
}

private:
std::unordered_map<Expr*, Expr*> replacement_map_;
};

struct FirstInnerMostScope : private OptInDispatch {
Expand Down

0 comments on commit d060620

Please sign in to comment.