From e6b0fc2aea078bdad316bde028dd118d6294e931 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Tue, 5 May 2020 08:20:37 +0900 Subject: [PATCH] [ir][refactor] Move all Expression subclasses to `frontend_ir.h` (#919) --- taichi/ir/frontend_ir.h | 340 ++++++++++++++++++++++++++++++++++++++++ taichi/ir/ir.h | 333 --------------------------------------- 2 files changed, 340 insertions(+), 333 deletions(-) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 151de699c02ba..4df05e6d2d78f 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -1,11 +1,16 @@ #pragma once +#include +#include + #include "taichi/lang_util.h" #include "taichi/ir/ir.h" #include "taichi/ir/expr.h" TLANG_NAMESPACE_BEGIN +// Frontend Statements + class FrontendAllocaStmt : public Stmt { public: Identifier ident; @@ -193,4 +198,339 @@ class FrontendWhileStmt : public Stmt { DEFINE_ACCEPT }; +// Expressions + +class ArgLoadExpression : public Expression { + public: + int arg_id; + + ArgLoadExpression(int arg_id) : arg_id(arg_id) { + } + + std::string serialize() override { + return fmt::format("arg[{}]", arg_id); + } + + void flatten(FlattenContext *ctx) override { + auto ran = std::make_unique(arg_id); + ctx->push_back(std::move(ran)); + stmt = ctx->back_stmt(); + } +}; + +class RandExpression : public Expression { + public: + DataType dt; + + RandExpression(DataType dt) : dt(dt) { + } + + std::string serialize() override { + return fmt::format("rand<{}>()", data_type_name(dt)); + } + + void flatten(FlattenContext *ctx) override { + auto ran = std::make_unique(dt); + ctx->push_back(std::move(ran)); + stmt = ctx->back_stmt(); + } +}; + +class UnaryOpExpression : public Expression { + public: + UnaryOpType type; + Expr operand; + DataType cast_type; + + UnaryOpExpression(UnaryOpType type, const Expr &operand) + : type(type), operand(smart_load(operand)) { + cast_type = DataType::unknown; + } + + bool is_cast() const; + + std::string serialize() override; + + void flatten(FlattenContext *ctx) override; +}; + +class BinaryOpExpression : public Expression { + public: + BinaryOpType type; + Expr lhs, rhs; + + BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) + : type(type) { + this->lhs.set(smart_load(lhs)); + this->rhs.set(smart_load(rhs)); + } + + std::string serialize() override { + return fmt::format("({} {} {})", lhs->serialize(), + binary_op_type_symbol(type), rhs->serialize()); + } + + void flatten(FlattenContext *ctx) override { + // if (stmt) + // return; + lhs->flatten(ctx); + rhs->flatten(ctx); + ctx->push_back(std::make_unique(type, lhs->stmt, rhs->stmt)); + ctx->stmts.back()->tb = tb; + stmt = ctx->back_stmt(); + } +}; + +class TernaryOpExpression : public Expression { + public: + TernaryOpType type; + Expr op1, op2, op3; + + TernaryOpExpression(TernaryOpType type, + const Expr &op1, + const Expr &op2, + const Expr &op3) + : type(type) { + this->op1.set(load_if_ptr(op1)); + this->op2.set(load_if_ptr(op2)); + this->op3.set(load_if_ptr(op3)); + } + + std::string serialize() override { + return fmt::format("{}({} {} {})", ternary_type_name(type), + op1->serialize(), op2->serialize(), op3->serialize()); + } + + void flatten(FlattenContext *ctx) override { + // if (stmt) + // return; + op1->flatten(ctx); + op2->flatten(ctx); + op3->flatten(ctx); + ctx->push_back( + std::make_unique(type, op1->stmt, op2->stmt, op3->stmt)); + stmt = ctx->back_stmt(); + } +}; + +class ExternalTensorExpression : public Expression { + public: + DataType dt; + int dim; + int arg_id; + + ExternalTensorExpression(const DataType &dt, int dim, int arg_id) + : dt(dt), dim(dim), arg_id(arg_id) { + set_attribute("dim", std::to_string(dim)); + } + + std::string serialize() override { + return fmt::format("{}d_ext_arr", dim); + } + + void flatten(FlattenContext *ctx) override { + auto ptr = Stmt::make(arg_id, true); + ctx->push_back(std::move(ptr)); + stmt = ctx->back_stmt(); + } +}; + +class GlobalVariableExpression : public Expression { + public: + Identifier ident; + DataType dt; + SNode *snode; + bool has_ambient; + TypedConstant ambient_value; + bool is_primal; + Expr adjoint; + + GlobalVariableExpression(DataType dt, const Identifier &ident) + : ident(ident), dt(dt) { + snode = nullptr; + has_ambient = false; + is_primal = true; + } + + GlobalVariableExpression(SNode *snode) : snode(snode) { + dt = snode->dt; + has_ambient = false; + is_primal = true; + } + + void set_snode(SNode *snode) { + this->snode = snode; + set_attribute("dim", std::to_string(snode->num_active_indices)); + } + + std::string serialize() override { + return "#" + ident.name(); + } + + void flatten(FlattenContext *ctx) override { + TI_ASSERT(snode->num_active_indices == 0); + auto ptr = Stmt::make(LaneAttribute(snode), + std::vector()); + ctx->push_back(std::move(ptr)); + } +}; + +class GlobalPtrExpression : public Expression { + public: + Expr var; + ExprGroup indices; + + GlobalPtrExpression(const Expr &var, const ExprGroup &indices) + : var(var), indices(indices) { + } + + std::string serialize() override; + + void flatten(FlattenContext *ctx) override; + + bool is_lvalue() const override { + return true; + } +}; + +class EvalExpression : public Expression { + public: + Stmt *stmt_ptr; + int stmt_id; + EvalExpression(Stmt *stmt) : stmt_ptr(stmt), stmt_id(stmt_ptr->id) { + // cache stmt->id since it may be released later + } + + std::string serialize() override { + return fmt::format("%{}", stmt_id); + } + + void flatten(FlattenContext *ctx) override { + stmt = stmt_ptr; + } +}; + +class RangeAssumptionExpression : public Expression { + public: + Expr input, base; + int low, high; + + RangeAssumptionExpression(const Expr &input, + const Expr &base, + int low, + int high) + : input(input), base(base), low(low), high(high) { + } + + std::string serialize() override { + return fmt::format("assume_in_range({}{:+d} <= ({}) < {}{:+d})", + base.serialize(), low, input.serialize(), + base.serialize(), high); + } + + void flatten(FlattenContext *ctx) override { + input->flatten(ctx); + base->flatten(ctx); + ctx->push_back( + Stmt::make(input->stmt, base->stmt, low, high)); + stmt = ctx->back_stmt(); + } +}; + +class IdExpression : public Expression { + public: + Identifier id; + IdExpression(const std::string &name = "") : id(name) { + } + IdExpression(const Identifier &id) : id(id) { + } + + std::string serialize() override { + return id.name(); + } + + void flatten(FlattenContext *ctx) override { + ctx->push_back(std::make_unique( + LocalAddress(ctx->current_block->lookup_var(id), 0))); + stmt = ctx->back_stmt(); + } + + bool is_lvalue() const override { + return true; + } +}; + +// ti.atomic_*() is an expression with side effect. +class AtomicOpExpression : public Expression { + public: + AtomicOpType op_type; + Expr dest, val; + + AtomicOpExpression(AtomicOpType op_type, const Expr &dest, const Expr &val) + : op_type(op_type), dest(dest), val(val) { + } + + std::string serialize() override; + + void flatten(FlattenContext *ctx) override; +}; + +class SNodeOpExpression : public Expression { + public: + SNode *snode; + SNodeOpType op_type; + ExprGroup indices; + Expr value; + + SNodeOpExpression(SNode *snode, SNodeOpType op_type, const ExprGroup &indices) + : snode(snode), op_type(op_type), indices(indices) { + } + + SNodeOpExpression(SNode *snode, + SNodeOpType op_type, + const ExprGroup &indices, + const Expr &value) + : snode(snode), op_type(op_type), indices(indices), value(value) { + } + + std::string serialize() override; + + void flatten(FlattenContext *ctx) override; +}; + +class GlobalLoadExpression : public Expression { + public: + Expr ptr; + GlobalLoadExpression(const Expr &ptr) : ptr(ptr) { + } + + std::string serialize() override { + return "gbl load " + ptr.serialize(); + } + + void flatten(FlattenContext *ctx) override { + ptr->flatten(ctx); + ctx->push_back(std::make_unique(ptr->stmt)); + stmt = ctx->back_stmt(); + } +}; + +class ConstExpression : public Expression { + public: + TypedConstant val; + + template + ConstExpression(const T &x) : val(x) { + } + + std::string serialize() override { + return val.stringify(); + } + + void flatten(FlattenContext *ctx) override { + ctx->push_back(Stmt::make(val)); + stmt = ctx->back_stmt(); + } +}; + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index d180ceeaef6cb..e80f17fe6ad5e 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -934,24 +934,6 @@ class ArgLoadStmt : public Stmt { DEFINE_ACCEPT }; -class ArgLoadExpression : public Expression { - public: - int arg_id; - - ArgLoadExpression(int arg_id) : arg_id(arg_id) { - } - - std::string serialize() override { - return fmt::format("arg[{}]", arg_id); - } - - void flatten(FlattenContext *ctx) override { - auto ran = std::make_unique(arg_id); - ctx->push_back(std::move(ran)); - stmt = ctx->back_stmt(); - } -}; - // For return values class ArgStoreStmt : public Stmt { public: @@ -986,42 +968,6 @@ class RandStmt : public Stmt { DEFINE_ACCEPT }; -class RandExpression : public Expression { - public: - DataType dt; - - RandExpression(DataType dt) : dt(dt) { - } - - std::string serialize() override { - return fmt::format("rand<{}>()", data_type_name(dt)); - } - - void flatten(FlattenContext *ctx) override { - auto ran = std::make_unique(dt); - ctx->push_back(std::move(ran)); - stmt = ctx->back_stmt(); - } -}; - -class UnaryOpExpression : public Expression { - public: - UnaryOpType type; - Expr operand; - DataType cast_type; - - UnaryOpExpression(UnaryOpType type, const Expr &operand) - : type(type), operand(smart_load(operand)) { - cast_type = DataType::unknown; - } - - bool is_cast() const; - - std::string serialize() override; - - void flatten(FlattenContext *ctx) override; -}; - class BinaryOpStmt : public Stmt { public: BinaryOpType op_type; @@ -1077,65 +1023,6 @@ class AtomicOpStmt : public Stmt { DEFINE_ACCEPT }; -class BinaryOpExpression : public Expression { - public: - BinaryOpType type; - Expr lhs, rhs; - - BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) - : type(type) { - this->lhs.set(smart_load(lhs)); - this->rhs.set(smart_load(rhs)); - } - - std::string serialize() override { - return fmt::format("({} {} {})", lhs->serialize(), - binary_op_type_symbol(type), rhs->serialize()); - } - - void flatten(FlattenContext *ctx) override { - // if (stmt) - // return; - lhs->flatten(ctx); - rhs->flatten(ctx); - ctx->push_back(std::make_unique(type, lhs->stmt, rhs->stmt)); - ctx->stmts.back()->tb = tb; - stmt = ctx->back_stmt(); - } -}; - -class TernaryOpExpression : public Expression { - public: - TernaryOpType type; - Expr op1, op2, op3; - - TernaryOpExpression(TernaryOpType type, - const Expr &op1, - const Expr &op2, - const Expr &op3) - : type(type) { - this->op1.set(load_if_ptr(op1)); - this->op2.set(load_if_ptr(op2)); - this->op3.set(load_if_ptr(op3)); - } - - std::string serialize() override { - return fmt::format("{}({} {} {})", ternary_type_name(type), - op1->serialize(), op2->serialize(), op3->serialize()); - } - - void flatten(FlattenContext *ctx) override { - // if (stmt) - // return; - op1->flatten(ctx); - op2->flatten(ctx); - op3->flatten(ctx); - ctx->push_back( - std::make_unique(type, op1->stmt, op2->stmt, op3->stmt)); - stmt = ctx->back_stmt(); - } -}; - class ExternalPtrStmt : public Stmt { public: LaneAttribute base_ptrs; @@ -1171,86 +1058,6 @@ class GlobalPtrStmt : public Stmt { DEFINE_ACCEPT }; -class ExternalTensorExpression : public Expression { - public: - DataType dt; - int dim; - int arg_id; - - ExternalTensorExpression(const DataType &dt, int dim, int arg_id) - : dt(dt), dim(dim), arg_id(arg_id) { - set_attribute("dim", std::to_string(dim)); - } - - std::string serialize() override { - return fmt::format("{}d_ext_arr", dim); - } - - void flatten(FlattenContext *ctx) override { - auto ptr = Stmt::make(arg_id, true); - ctx->push_back(std::move(ptr)); - stmt = ctx->back_stmt(); - } -}; - -class GlobalVariableExpression : public Expression { - public: - Identifier ident; - DataType dt; - SNode *snode; - bool has_ambient; - TypedConstant ambient_value; - bool is_primal; - Expr adjoint; - - GlobalVariableExpression(DataType dt, const Identifier &ident) - : ident(ident), dt(dt) { - snode = nullptr; - has_ambient = false; - is_primal = true; - } - - GlobalVariableExpression(SNode *snode) : snode(snode) { - dt = snode->dt; - has_ambient = false; - is_primal = true; - } - - void set_snode(SNode *snode) { - this->snode = snode; - set_attribute("dim", std::to_string(snode->num_active_indices)); - } - - std::string serialize() override { - return "#" + ident.name(); - } - - void flatten(FlattenContext *ctx) override { - TI_ASSERT(snode->num_active_indices == 0); - auto ptr = Stmt::make(LaneAttribute(snode), - std::vector()); - ctx->push_back(std::move(ptr)); - } -}; - -class GlobalPtrExpression : public Expression { - public: - Expr var; - ExprGroup indices; - - GlobalPtrExpression(const Expr &var, const ExprGroup &indices) - : var(var), indices(indices) { - } - - std::string serialize() override; - - void flatten(FlattenContext *ctx) override; - - bool is_lvalue() const override { - return true; - } -}; - #include "expression.h" class Block : public IRNode { @@ -1666,146 +1473,6 @@ class WhileStmt : public Stmt { void Print_(const Expr &a, const std::string &str); -class EvalExpression : public Expression { - public: - Stmt *stmt_ptr; - int stmt_id; - EvalExpression(Stmt *stmt) : stmt_ptr(stmt), stmt_id(stmt_ptr->id) { - // cache stmt->id since it may be released later - } - - std::string serialize() override { - return fmt::format("%{}", stmt_id); - } - - void flatten(FlattenContext *ctx) override { - stmt = stmt_ptr; - } -}; - -class RangeAssumptionExpression : public Expression { - public: - Expr input, base; - int low, high; - - RangeAssumptionExpression(const Expr &input, - const Expr &base, - int low, - int high) - : input(input), base(base), low(low), high(high) { - } - - std::string serialize() override { - return fmt::format("assume_in_range({}{:+d} <= ({}) < {}{:+d})", - base.serialize(), low, input.serialize(), - base.serialize(), high); - } - - void flatten(FlattenContext *ctx) override { - input->flatten(ctx); - base->flatten(ctx); - ctx->push_back( - Stmt::make(input->stmt, base->stmt, low, high)); - stmt = ctx->back_stmt(); - } -}; - -class IdExpression : public Expression { - public: - Identifier id; - IdExpression(const std::string &name = "") : id(name) { - } - IdExpression(const Identifier &id) : id(id) { - } - - std::string serialize() override { - return id.name(); - } - - void flatten(FlattenContext *ctx) override { - ctx->push_back(std::make_unique( - LocalAddress(ctx->current_block->lookup_var(id), 0))); - stmt = ctx->back_stmt(); - } - - bool is_lvalue() const override { - return true; - } -}; - -// ti.atomic_*() is an expression with side effect. -class AtomicOpExpression : public Expression { - public: - AtomicOpType op_type; - Expr dest, val; - - AtomicOpExpression(AtomicOpType op_type, const Expr &dest, const Expr &val) - : op_type(op_type), dest(dest), val(val) { - } - - std::string serialize() override; - - void flatten(FlattenContext *ctx) override; -}; - -class SNodeOpExpression : public Expression { - public: - SNode *snode; - SNodeOpType op_type; - ExprGroup indices; - Expr value; - - SNodeOpExpression(SNode *snode, SNodeOpType op_type, const ExprGroup &indices) - : snode(snode), op_type(op_type), indices(indices) { - } - - SNodeOpExpression(SNode *snode, - SNodeOpType op_type, - const ExprGroup &indices, - const Expr &value) - : snode(snode), op_type(op_type), indices(indices), value(value) { - } - - std::string serialize() override; - - void flatten(FlattenContext *ctx) override; -}; - -class GlobalLoadExpression : public Expression { - public: - Expr ptr; - GlobalLoadExpression(const Expr &ptr) : ptr(ptr) { - } - - std::string serialize() override { - return "gbl load " + ptr.serialize(); - } - - void flatten(FlattenContext *ctx) override { - ptr->flatten(ctx); - ctx->push_back(std::make_unique(ptr->stmt)); - stmt = ctx->back_stmt(); - } -}; - -class ConstExpression : public Expression { - public: - TypedConstant val; - - template - ConstExpression(const T &x) : val(x) { - } - - std::string serialize() override { - return val.stringify(); - } - - void flatten(FlattenContext *ctx) override { - ctx->push_back(Stmt::make(val)); - stmt = ctx->back_stmt(); - } -}; - extern DecoratorRecorder dec; inline void Vectorize(int v) {