From b5283a973863307824ecff0b3c79a71adefadcd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E4=BA=8E=E6=96=8C?= <1931127624@qq.com> Date: Thu, 9 Apr 2020 20:44:53 +0800 Subject: [PATCH] [ir] [refactor] BasicStmtVisitor includes Frontend sstatements (#732) --- taichi/analysis/detect_for_with_break.cpp | 21 +++++++-------------- taichi/ir/basic_stmt_visitor.cpp | 15 +++++++++++++++ taichi/ir/visitors.h | 6 ++++++ 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/taichi/analysis/detect_for_with_break.cpp b/taichi/analysis/detect_for_with_break.cpp index d3a3cde111972..ad818f357276d 100644 --- a/taichi/analysis/detect_for_with_break.cpp +++ b/taichi/analysis/detect_for_with_break.cpp @@ -15,6 +15,13 @@ class DetectForWithBreak : public BasicStmtVisitor { DetectForWithBreak(IRNode *root) : root(root) { } + void visit(FrontendBreakStmt *stmt) override { + TI_ASSERT_INFO(loop_stack.size() != 0, "break statement out of loop scope"); + auto loop = loop_stack.back(); + if (loop->is()) + fors_with_break.insert(loop); + } + void visit(FrontendWhileStmt *stmt) override { loop_stack.push_back(stmt); stmt->body->accept(this); @@ -27,20 +34,6 @@ class DetectForWithBreak : public BasicStmtVisitor { loop_stack.pop_back(); } - void visit(FrontendIfStmt *stmt) override { - if (stmt->true_statements) - stmt->true_statements->accept(this); - if (stmt->false_statements) - stmt->false_statements->accept(this); - } - - void visit(FrontendBreakStmt *stmt) override { - TI_ASSERT_INFO(loop_stack.size() != 0, "break statement out of loop scope"); - auto loop = loop_stack.back(); - if (loop->is()) - fors_with_break.insert(loop); - } - std::unordered_set run() { root->accept(this); return fors_with_break; diff --git a/taichi/ir/basic_stmt_visitor.cpp b/taichi/ir/basic_stmt_visitor.cpp index b68bdc71e6cc0..8580a2c4b992f 100644 --- a/taichi/ir/basic_stmt_visitor.cpp +++ b/taichi/ir/basic_stmt_visitor.cpp @@ -43,4 +43,19 @@ void BasicStmtVisitor::visit(OffloadedStmt *stmt) { stmt->body->accept(this); } +void BasicStmtVisitor::visit(FrontendWhileStmt *stmt) { + stmt->body->accept(this); +} + +void BasicStmtVisitor::visit(FrontendForStmt *stmt) { + stmt->body->accept(this); +} + +void BasicStmtVisitor::visit(FrontendIfStmt *stmt) { + if (stmt->true_statements) + stmt->true_statements->accept(this); + if (stmt->false_statements) + stmt->false_statements->accept(this); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/visitors.h b/taichi/ir/visitors.h index 50c47c077d669..e1ab43caca408 100644 --- a/taichi/ir/visitors.h +++ b/taichi/ir/visitors.h @@ -22,6 +22,12 @@ class BasicStmtVisitor : public IRVisitor { void visit(StructForStmt *for_stmt) override; void visit(OffloadedStmt *stmt) override; + + void visit(FrontendWhileStmt *stmt) override; + + void visit(FrontendForStmt *stmt) override; + + void visit(FrontendIfStmt *stmt) override; }; TLANG_NAMESPACE_END