Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] [refactor] Slim ir.h and reduce the build time by ~4% #761

Merged
merged 9 commits into from
Apr 13, 2020
135 changes: 135 additions & 0 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "expr.h"
#include "ir.h"
#include "taichi/program/program.h"

TLANG_NAMESPACE_BEGIN

Expand All @@ -20,4 +21,138 @@ std::string Expr::get_attribute(const std::string &key) const {
return expr->get_attribute(key);
}

Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val) {
return Expr::make<TernaryOpExpression>(TernaryOpType::select, cond, true_val,
false_val);
}

Expr operator-(const Expr &expr) {
return Expr::make<UnaryOpExpression>(UnaryOpType::neg, expr);
}

Expr operator~(const Expr &expr) {
return Expr::make<UnaryOpExpression>(UnaryOpType::bit_not, expr);
}

Expr cast(const Expr &input, DataType dt) {
auto ret = std::make_shared<UnaryOpExpression>(UnaryOpType::cast, input);
ret->cast_type = dt;
ret->cast_by_value = true;
return Expr(ret);
}

Expr bit_cast(const Expr &input, DataType dt) {
auto ret = std::make_shared<UnaryOpExpression>(UnaryOpType::cast, input);
ret->cast_type = dt;
ret->cast_by_value = false;
return Expr(ret);
}

Expr Expr::operator[](const ExprGroup &indices) const {
TI_ASSERT(is<GlobalVariableExpression>() || is<ExternalTensorExpression>());
return Expr::make<GlobalPtrExpression>(*this, indices.loaded());
}

Expr &Expr::operator=(const Expr &o) {
if (get_current_program().current_kernel) {
if (expr == nullptr) {
set(o.eval());
} else if (expr->is_lvalue()) {
current_ast_builder().insert(std::make_unique<FrontendAssignStmt>(
ptr_if_global(*this), load_if_ptr(o)));
} else {
// set(o.eval());
TI_ERROR("Cannot assign to non-lvalue: {}", serialize());
}
} else {
set(o);
}
return *this;
}

Expr Expr::parent() const {
TI_ASSERT(is<GlobalVariableExpression>());
return Expr::make<GlobalVariableExpression>(
cast<GlobalVariableExpression>()->snode->parent);
}

SNode *Expr::snode() const {
TI_ASSERT(is<GlobalVariableExpression>());
return cast<GlobalVariableExpression>()->snode;
}

Expr Expr::operator!() {
return Expr::make<UnaryOpExpression>(UnaryOpType::logic_not, expr);
}

void Expr::declare(DataType dt) {
set(Expr::make<GlobalVariableExpression>(dt, Identifier()));
}

void Expr::set_grad(const Expr &o) {
this->cast<GlobalVariableExpression>()->adjoint.set(o);
}

Expr::Expr(int32 x) : Expr() {
expr = std::make_shared<ConstExpression>(x);
}

Expr::Expr(int64 x) : Expr() {
expr = std::make_shared<ConstExpression>(x);
}

Expr::Expr(float32 x) : Expr() {
expr = std::make_shared<ConstExpression>(x);
}

Expr::Expr(float64 x) : Expr() {
expr = std::make_shared<ConstExpression>(x);
}

Expr::Expr(const Identifier &id) : Expr() {
expr = std::make_shared<IdExpression>(id);
}

Expr Expr::eval() const {
TI_ASSERT(expr != nullptr);
if (is<EvalExpression>()) {
return *this;
}
auto eval_stmt = Stmt::make<FrontendEvalStmt>(*this);
auto eval_expr = Expr::make<EvalExpression>(eval_stmt.get());
eval_stmt->as<FrontendEvalStmt>()->eval_expr.set(eval_expr);
// needed in lower_ast to replace the statement itself with the
// lowered statement
current_ast_builder().insert(std::move(eval_stmt));
return eval_expr;
}

void Expr::operator+=(const Expr &o) {
if (this->atomic) {
current_ast_builder().insert(Stmt::make<FrontendAtomicStmt>(
AtomicOpType::add, ptr_if_global(*this), load_if_ptr(o)));
} else {
(*this) = (*this) + o;
}
}

void Expr::operator-=(const Expr &o) {
if (this->atomic) {
current_ast_builder().insert(Stmt::make<FrontendAtomicStmt>(
AtomicOpType::add, *this, -load_if_ptr(o)));
} else {
(*this) = (*this) - o;
}
}

void Expr::operator*=(const Expr &o) {
TI_ASSERT(!this->atomic);
(*this) = (*this) * load_if_ptr(o);
}

void Expr::operator/=(const Expr &o) {
TI_ASSERT(!this->atomic);
(*this) = (*this) / load_if_ptr(o);
}

TLANG_NAMESPACE_END
21 changes: 21 additions & 0 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,25 @@ class Expr {
std::string get_attribute(const std::string &key) const;
};

Expr select(const Expr &cond, const Expr &true_val, const Expr &false_val);

Expr operator-(const Expr &expr);

Expr operator~(const Expr &expr);

// Value cast
Expr cast(const Expr &input, DataType dt);

template <typename T>
Expr cast(const Expr &input) {
return taichi::lang::cast(input, get_data_type<T>());
}

Expr bit_cast(const Expr &input, DataType dt);

template <typename T>
Expr bit_cast(const Expr &input) {
return taichi::lang::bit_cast(input, get_data_type<T>());
}

TLANG_NAMESPACE_END
Loading