Skip to content

Commit

Permalink
[ir][refactor] Move legacy frontend constructs to frontend.h/cpp (#924)
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye authored May 6, 2020
1 parent 09881d4 commit beab17e
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 93 deletions.
64 changes: 64 additions & 0 deletions taichi/ir/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,68 @@ Expr global_new(DataType dt, std::string name) {
return Expr::make<GlobalVariableExpression>(dt, id_expr->id);
}

// Begin: legacy frontend constructs

If::If(const Expr &cond) {
auto stmt_tmp = std::make_unique<FrontendIfStmt>(cond);
stmt = stmt_tmp.get();
current_ast_builder().insert(std::move(stmt_tmp));
}

If::If(const Expr &cond, const std::function<void()> &func) : If(cond) {
Then(func);
}

If &If::Then(const std::function<void()> &func) {
auto _ = current_ast_builder().create_scope(stmt->true_statements);
func();
return *this;
}

If &If::Else(const std::function<void()> &func) {
auto _ = current_ast_builder().create_scope(stmt->false_statements);
func();
return *this;
}

For::For(const Expr &s, const Expr &e, const std::function<void(Expr)> &func) {
auto i = Expr(std::make_shared<IdExpression>());
auto stmt_unique = std::make_unique<FrontendForStmt>(i, s, e);
auto stmt = stmt_unique.get();
current_ast_builder().insert(std::move(stmt_unique));
auto _ = current_ast_builder().create_scope(stmt->body);
func(i);
}

For::For(const Expr &i,
const Expr &s,
const Expr &e,
const std::function<void()> &func) {
auto stmt_unique = std::make_unique<FrontendForStmt>(i, s, e);
auto stmt = stmt_unique.get();
current_ast_builder().insert(std::move(stmt_unique));
auto _ = current_ast_builder().create_scope(stmt->body);
func();
}

For::For(const ExprGroup &i,
const Expr &global,
const std::function<void()> &func) {
auto stmt_unique = std::make_unique<FrontendForStmt>(i, global);
auto stmt = stmt_unique.get();
current_ast_builder().insert(std::move(stmt_unique));
auto _ = current_ast_builder().create_scope(stmt->body);
func();
}

While::While(const Expr &cond, const std::function<void()> &func) {
auto while_stmt = std::make_unique<FrontendWhileStmt>(cond);
FrontendWhileStmt *ptr = while_stmt.get();
current_ast_builder().insert(std::move(while_stmt));
auto _ = current_ast_builder().create_scope(ptr->body);
func();
}

// End: legacy frontend constructs

TLANG_NAMESPACE_END
36 changes: 36 additions & 0 deletions taichi/ir/frontend.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,42 @@ inline Expr AssumeInRange(const Expr &expr,
return Expr::make<RangeAssumptionExpression>(expr, base, low, high);
}

// Begin: legacy frontend constructs

class If {
public:
FrontendIfStmt *stmt;

explicit If(const Expr &cond);

If(const Expr &cond, const std::function<void()> &func);

If &Then(const std::function<void()> &func);

If &Else(const std::function<void()> &func);
};

class For {
public:
For(const Expr &i,
const Expr &s,
const Expr &e,
const std::function<void()> &func);

For(const ExprGroup &i,
const Expr &global,
const std::function<void()> &func);

For(const Expr &s, const Expr &e, const std::function<void(Expr)> &func);
};

class While {
public:
While(const Expr &cond, const std::function<void()> &func);
};

// End: legacy frontend constructs

#define Kernel(x) auto &x = get_current_program().kernel(#x)
#define Assert(x) InsertAssert(#x, (x))

Expand Down
61 changes: 0 additions & 61 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <unordered_map>

#include "taichi/ir/frontend.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/ir/statements.h"

TLANG_NAMESPACE_BEGIN
Expand Down Expand Up @@ -910,36 +909,6 @@ std::unique_ptr<ConstStmt> ConstStmt::copy() {
return std::make_unique<ConstStmt>(val);
}

For::For(const Expr &s, const Expr &e, const std::function<void(Expr)> &func) {
auto i = Expr(std::make_shared<IdExpression>());
auto stmt_unique = std::make_unique<FrontendForStmt>(i, s, e);
auto stmt = stmt_unique.get();
current_ast_builder().insert(std::move(stmt_unique));
auto _ = current_ast_builder().create_scope(stmt->body);
func(i);
}

For::For(const Expr &i,
const Expr &s,
const Expr &e,
const std::function<void()> &func) {
auto stmt_unique = std::make_unique<FrontendForStmt>(i, s, e);
auto stmt = stmt_unique.get();
current_ast_builder().insert(std::move(stmt_unique));
auto _ = current_ast_builder().create_scope(stmt->body);
func();
}

For::For(const ExprGroup &i,
const Expr &global,
const std::function<void()> &func) {
auto stmt_unique = std::make_unique<FrontendForStmt>(i, global);
auto stmt = stmt_unique.get();
current_ast_builder().insert(std::move(stmt_unique));
auto _ = current_ast_builder().create_scope(stmt->body);
func();
}

OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type)
: OffloadedStmt(task_type, nullptr) {
}
Expand Down Expand Up @@ -1005,34 +974,4 @@ bool ContinueStmt::as_return() const {
return false;
}

If::If(const Expr &cond) {
auto stmt_tmp = std::make_unique<FrontendIfStmt>(cond);
stmt = stmt_tmp.get();
current_ast_builder().insert(std::move(stmt_tmp));
}

If::If(const Expr &cond, const std::function<void()> &func) : If(cond) {
Then(func);
}

If &If::Then(const std::function<void()> &func) {
auto _ = current_ast_builder().create_scope(stmt->true_statements);
func();
return *this;
}

If &If::Else(const std::function<void()> &func) {
auto _ = current_ast_builder().create_scope(stmt->false_statements);
func();
return *this;
}

While::While(const Expr &cond, const std::function<void()> &func) {
auto while_stmt = std::make_unique<FrontendWhileStmt>(cond);
FrontendWhileStmt *ptr = while_stmt.get();
current_ast_builder().insert(std::move(while_stmt));
auto _ = current_ast_builder().create_scope(ptr->body);
func();
}

TLANG_NAMESPACE_END
32 changes: 0 additions & 32 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1291,19 +1291,6 @@ class PrintStmt : public Stmt {
DEFINE_ACCEPT
};

class If {
public:
FrontendIfStmt *stmt;

explicit If(const Expr &cond);

If(const Expr &cond, const std::function<void()> &func);

If &Then(const std::function<void()> &func);

If &Else(const std::function<void()> &func);
};

class ConstStmt : public Stmt {
public:
LaneAttribute<TypedConstant> val;
Expand Down Expand Up @@ -1504,25 +1491,6 @@ inline void SLP(int v) {
current_ast_builder().insert(Stmt::make<PragmaSLPStmt>(v));
}

class For {
public:
For(const Expr &i,
const Expr &s,
const Expr &e,
const std::function<void()> &func);

For(const ExprGroup &i,
const Expr &global,
const std::function<void()> &func);

For(const Expr &s, const Expr &e, const std::function<void(Expr)> &func);
};

class While {
public:
While(const Expr &cond, const std::function<void()> &func);
};

Expr Var(const Expr &x);

class VectorElement {
Expand Down

0 comments on commit beab17e

Please sign in to comment.