Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Add an IR Builder with some basic functions #2204

Merged
merged 6 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ DecoratorRecorder dec;

FrontendContext::FrontendContext() {
root_node = std::make_unique<Block>();
current_builder = std::make_unique<IRBuilder>(root_node.get());
current_builder = std::make_unique<ASTBuilder>(root_node.get());
}

FrontendForStmt::FrontendForStmt(const Expr &loop_var,
Expand Down
12 changes: 6 additions & 6 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ std::string snode_access_flag_name(SNodeAccessFlag type) {
}
}

IRBuilder &current_ast_builder() {
ASTBuilder &current_ast_builder() {
return context->builder();
}

Expand All @@ -41,29 +41,29 @@ void DecoratorRecorder::reset() {
strictly_serialized = false;
}

Block *IRBuilder::current_block() {
Block *ASTBuilder::current_block() {
if (stack.empty())
return nullptr;
else
return stack.back();
}

Stmt *IRBuilder::get_last_stmt() {
Stmt *ASTBuilder::get_last_stmt() {
TI_ASSERT(!stack.empty());
return stack.back()->back();
}

void IRBuilder::insert(std::unique_ptr<Stmt> &&stmt, int location) {
void ASTBuilder::insert(std::unique_ptr<Stmt> &&stmt, int location) {
TI_ASSERT(!stack.empty());
stack.back()->insert(std::move(stmt), location);
}

void IRBuilder::stop_gradient(SNode *snode) {
void ASTBuilder::stop_gradient(SNode *snode) {
TI_ASSERT(!stack.empty());
stack.back()->stop_gradients.push_back(snode);
}

std::unique_ptr<IRBuilder::ScopeGuard> IRBuilder::create_scope(
std::unique_ptr<ASTBuilder::ScopeGuard> ASTBuilder::create_scope(
std::unique_ptr<Block> &list) {
TI_ASSERT(list == nullptr);
list = std::make_unique<Block>();
Expand Down
18 changes: 10 additions & 8 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

TLANG_NAMESPACE_BEGIN

class IRBuilder;
class ASTBuilder;
class IRNode;
class Block;
class Stmt;
Expand Down Expand Up @@ -70,7 +70,7 @@ class MemoryAccessOptions {
#include "taichi/inc/statements.inc.h"
#undef PER_STATEMENT

IRBuilder &current_ast_builder();
ASTBuilder &current_ast_builder();

class DecoratorRecorder {
public:
Expand All @@ -91,13 +91,13 @@ class DecoratorRecorder {

class FrontendContext {
private:
std::unique_ptr<IRBuilder> current_builder;
std::unique_ptr<ASTBuilder> current_builder;
std::unique_ptr<Block> root_node;

public:
FrontendContext();

IRBuilder &builder() {
ASTBuilder &builder() {
return *current_builder;
}

Expand All @@ -110,21 +110,23 @@ class FrontendContext {

extern std::unique_ptr<FrontendContext> context;

class IRBuilder {
// TODO: move to frontend_ir.h
class ASTBuilder {
private:
std::vector<Block *> stack;

public:
IRBuilder(Block *initial) {
ASTBuilder(Block *initial) {
stack.push_back(initial);
}

void insert(std::unique_ptr<Stmt> &&stmt, int location = -1);

struct ScopeGuard {
IRBuilder *builder;
ASTBuilder *builder;
Block *list;
ScopeGuard(IRBuilder *builder, Block *list) : builder(builder), list(list) {
ScopeGuard(ASTBuilder *builder, Block *list)
: builder(builder), list(list) {
builder->stack.push_back(list);
}

Expand Down
77 changes: 77 additions & 0 deletions taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"

TLANG_NAMESPACE_BEGIN

IRBuilder::IRBuilder() {
root_ = std::make_unique<Block>();
insert_point_.block = root_->as<Block>();
insert_point_.position = 0;
}

Stmt *IRBuilder::insert(std::unique_ptr<Stmt> &&stmt) {
return insert_point_.block->insert(std::move(stmt), insert_point_.position++);
}

Stmt *IRBuilder::get_int32(int32 value) {
return insert(
Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(TypedConstant(
TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i32),
value))));
}

Stmt *IRBuilder::get_int64(int64 value) {
return insert(
Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(TypedConstant(
TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::i64),
value))));
}

Stmt *IRBuilder::get_float32(float32 value) {
return insert(
Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(TypedConstant(
TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f32),
value))));
}

Stmt *IRBuilder::get_float64(float64 value) {
return insert(
Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(TypedConstant(
TypeFactory::get_instance().get_primitive_type(PrimitiveTypeID::f64),
value))));
}

Stmt *IRBuilder::get_argument(int arg_id, DataType dt, bool is_ptr) {
return insert(Stmt::make<ArgLoadStmt>(arg_id, dt, is_ptr));
}

Stmt *IRBuilder::create_add(Stmt *l, Stmt *r) {
return insert(Stmt::make<BinaryOpStmt>(BinaryOpType::add, l, r));
}

Stmt *IRBuilder::create_sub(Stmt *l, Stmt *r) {
return insert(Stmt::make<BinaryOpStmt>(BinaryOpType::sub, l, r));
}

Stmt *IRBuilder::create_mul(Stmt *l, Stmt *r) {
return insert(Stmt::make<BinaryOpStmt>(BinaryOpType::mul, l, r));
}

Stmt *IRBuilder::create_div(Stmt *l, Stmt *r) {
return insert(Stmt::make<BinaryOpStmt>(BinaryOpType::div, l, r));
}

Stmt *IRBuilder::create_floordiv(Stmt *l, Stmt *r) {
return insert(Stmt::make<BinaryOpStmt>(BinaryOpType::floordiv, l, r));
}

Stmt *IRBuilder::create_truediv(Stmt *l, Stmt *r) {
return insert(Stmt::make<BinaryOpStmt>(BinaryOpType::truediv, l, r));
}

template <typename... Args>
Stmt *IRBuilder::create_print(Args &&... args) {
return insert(Stmt::make<PrintStmt>(std::forward(args)));
}

TLANG_NAMESPACE_END
50 changes: 50 additions & 0 deletions taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once

#include "taichi/ir/ir.h"

TLANG_NAMESPACE_BEGIN

class IRBuilder {
public:
struct InsertPoint {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: does this need to be public?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if we need to modify the insertion point... I think it's fine to make it private now.

Block *block;
int position;
};

private:
std::unique_ptr<IRNode> root_;
InsertPoint insert_point_;

public:
IRBuilder();

// General inserter. Returns stmt.get().
Stmt *insert(std::unique_ptr<Stmt> &&stmt);

// Constants. TODO: add more types
Stmt *get_int32(int32 value);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may also consider a template <typename T> Stmt *get_const(T val) template function and then specialize for different types (or use if constexpr (std::is_same_v<T, int32>)). Just a random idea and the decision is yours :-)

Stmt *get_int64(int64 value);
Stmt *get_float32(float32 value);
Stmt *get_float64(float64 value);

// Kernel arguments.
Stmt *get_argument(int arg_id, DataType dt, bool is_ptr);
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved

// Binary operations. Returns the result.
Stmt *create_add(Stmt *l, Stmt *r);
Stmt *create_sub(Stmt *l, Stmt *r);
Stmt *create_mul(Stmt *l, Stmt *r);

// l / r in C++
Stmt *create_div(Stmt *l, Stmt *r);
// floor(1.0 * l / r) in C++
Stmt *create_floordiv(Stmt *l, Stmt *r);
// 1.0 * l / r in C++
Stmt *create_truediv(Stmt *l, Stmt *r);

// Print values and strings. Arguments can be Stmt* or std::string.
template <typename... Args>
Stmt *create_print(Args &&... args);
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
};

TLANG_NAMESPACE_END
2 changes: 1 addition & 1 deletion taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void expr_assign(const Expr &lhs_, const Expr &rhs, std::string tb) {
current_ast_builder().insert(std::move(stmt));
}

std::vector<std::unique_ptr<IRBuilder::ScopeGuard>> scope_stack;
std::vector<std::unique_ptr<ASTBuilder::ScopeGuard>> scope_stack;

void compile_runtimes();
std::string libdevice_path();
Expand Down