diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h new file mode 100644 index 0000000000000..faba902169142 --- /dev/null +++ b/taichi/inc/expressions.inc.h @@ -0,0 +1,21 @@ +PER_EXPRESSION(ArgLoadExpression) +PER_EXPRESSION(RandExpression) +PER_EXPRESSION(UnaryOpExpression) +PER_EXPRESSION(BinaryOpExpression) +PER_EXPRESSION(TernaryOpExpression) +PER_EXPRESSION(InternalFuncCallExpression) +PER_EXPRESSION(ExternalTensorExpression) +PER_EXPRESSION(GlobalVariableExpression) +PER_EXPRESSION(GlobalPtrExpression) +PER_EXPRESSION(TensorElementExpression) +PER_EXPRESSION(RangeAssumptionExpression) +PER_EXPRESSION(LoopUniqueExpression) +PER_EXPRESSION(IdExpression) +PER_EXPRESSION(AtomicOpExpression) +PER_EXPRESSION(SNodeOpExpression) +PER_EXPRESSION(ConstExpression) +PER_EXPRESSION(ExternalTensorShapeAlongAxisExpression) +PER_EXPRESSION(FuncCallExpression) +PER_EXPRESSION(MeshPatchIndexExpression) +PER_EXPRESSION(MeshRelationAccessExpression) +PER_EXPRESSION(MeshIndexConversionExpression) diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index 20e08342daf14..ff7ae1385d5dc 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -6,17 +6,6 @@ TLANG_NAMESPACE_BEGIN -void Expr::serialize(std::ostream &ss) const { - TI_ASSERT(expr); - expr->serialize(ss); -} - -std::string Expr::serialize() const { - std::stringstream ss; - serialize(ss); - return ss.str(); -} - void Expr::set_tb(const std::string &tb) { expr->tb = tb; } diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index be970cd7c1267..1965492b985ee 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -83,9 +83,6 @@ class Expr { Expr operator[](const ExprGroup &indices) const; - std::string serialize() const; - void serialize(std::ostream &ss) const; - Expr operator!(); template diff --git a/taichi/ir/expression.cpp b/taichi/ir/expression.cpp index 493ace469d00f..f34f90f2384f9 100644 --- a/taichi/ir/expression.cpp +++ b/taichi/ir/expression.cpp @@ -11,20 +11,5 @@ std::string Expression::get_attribute(const std::string &key) const { } } -void ExprGroup::serialize(std::ostream &ss) const { - for (int i = 0; i < (int)exprs.size(); i++) { - exprs[i].serialize(ss); - if (i + 1 < (int)exprs.size()) { - ss << ", "; - } - } -} - -std::string ExprGroup::serialize() const { - std::stringstream ss; - serialize(ss); - return ss.str(); -} - } // namespace lang } // namespace taichi diff --git a/taichi/ir/expression.h b/taichi/ir/expression.h index a59ee3521d9e7..09ede3cbbbbd7 100644 --- a/taichi/ir/expression.h +++ b/taichi/ir/expression.h @@ -7,6 +7,8 @@ TLANG_NAMESPACE_BEGIN +class ExpressionVisitor; + // always a tree - used as rvalues class Expression { public: @@ -42,7 +44,7 @@ class Expression { // implemented } - virtual void serialize(std::ostream &ss) = 0; + virtual void accept(ExpressionVisitor *visitor) = 0; virtual void flatten(FlattenContext *ctx) { TI_NOT_IMPLEMENTED; @@ -110,10 +112,6 @@ class ExprGroup { Expr &operator[](int i) { return exprs[i]; } - - void serialize(std::ostream &ss) const; - - std::string serialize() const; }; inline ExprGroup operator,(const Expr &a, const Expr &b) { @@ -124,4 +122,53 @@ inline ExprGroup operator,(const ExprGroup &a, const Expr &b) { return ExprGroup(a, b); } +#define PER_EXPRESSION(x) class x; +#include "taichi/inc/expressions.inc.h" +#undef PER_EXPRESSION + +class ExpressionVisitor { + public: + ExpressionVisitor(bool allow_undefined_visitor = false, + bool invoke_default_visitor = false) + : allow_undefined_visitor_(allow_undefined_visitor), + invoke_default_visitor_(invoke_default_visitor) { + } + + virtual ~ExpressionVisitor() = default; + + virtual void visit(ExprGroup &expr_group) = 0; + + void visit(Expr &expr) { + expr.expr->accept(this); + } + + virtual void visit(Expression *expr) { + if (!allow_undefined_visitor_) { + TI_ERROR("missing visitor function"); + } + } + +#define DEFINE_VISIT(T) \ + virtual void visit(T *expr) { \ + if (allow_undefined_visitor_) { \ + if (invoke_default_visitor_) \ + visit((Expression *)expr); \ + } else \ + TI_NOT_IMPLEMENTED; \ + } + +#define PER_EXPRESSION(x) DEFINE_VISIT(x) +#include "taichi/inc/expressions.inc.h" +#undef PER_EXPRESSION +#undef DEFINE_VISIT + private: + bool allow_undefined_visitor_{false}; + bool invoke_default_visitor_{false}; +}; + +#define TI_DEFINE_ACCEPT_FOR_EXPRESSION \ + void accept(ExpressionVisitor *visitor) override { \ + visitor->visit(this); \ + } + TLANG_NAMESPACE_END diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h new file mode 100644 index 0000000000000..eb8963f3e077f --- /dev/null +++ b/taichi/ir/expression_printer.h @@ -0,0 +1,251 @@ +#pragma once + +#include "taichi/ir/expr.h" +#include "taichi/ir/expression.h" +#include "taichi/ir/frontend_ir.h" + +namespace taichi { +namespace lang { + +class ExpressionHumanFriendlyPrinter : public ExpressionVisitor { + public: + ExpressionHumanFriendlyPrinter(std::ostream *os = nullptr) : os_(os) { + } + + void set_ostream(std::ostream *os) { + os_ = os; + } + + std::ostream &get_ostream() { + TI_ASSERT(os_); + return *os_; + } + + void visit(ExprGroup &expr_group) override { + emit_vector(expr_group.exprs); + } + + void visit(ArgLoadExpression *expr) override { + emit( + fmt::format("arg[{}] (dt={})", expr->arg_id, data_type_name(expr->dt))); + } + + void visit(RandExpression *expr) override { + emit(fmt::format("rand<{}>()", data_type_name(expr->dt))); + } + + void visit(UnaryOpExpression *expr) override { + emit('('); + if (expr->is_cast()) { + emit(expr->type == UnaryOpType::cast_value ? "" : "reinterpret_"); + emit(unary_op_type_name(expr->type)); + emit('<', data_type_name(expr->cast_type), "> "); + } else { + emit(unary_op_type_name(expr->type), ' '); + } + expr->operand->accept(this); + emit(')'); + } + + void visit(BinaryOpExpression *expr) override { + emit('('); + expr->lhs->accept(this); + emit(' ', binary_op_type_symbol(expr->type), ' '); + expr->rhs->accept(this); + emit(')'); + } + + void visit(TernaryOpExpression *expr) override { + emit(ternary_type_name(expr->type), '('); + expr->op1->accept(this); + emit(' '); + expr->op2->accept(this); + emit(' '); + expr->op3->accept(this); + emit(')'); + } + + void visit(InternalFuncCallExpression *expr) override { + emit("internal call ", expr->func_name, '('); + if (expr->with_runtime_context) { + emit("runtime, "); + } + emit_vector(expr->args); + emit(')'); + } + + void visit(ExternalTensorExpression *expr) override { + emit(fmt::format("{}d_ext_arr (element_dim={}, dt={})", expr->dim, + expr->element_dim, expr->dt->to_string())); + } + + void visit(GlobalVariableExpression *expr) override { + emit("#", expr->ident.name()); + if (expr->snode) { + emit( + fmt::format(" (snode={})", expr->snode->get_node_type_name_hinted())); + } else { + emit(fmt::format(" (dt={})", expr->dt->to_string())); + } + } + + void visit(GlobalPtrExpression *expr) override { + if (expr->snode) { + emit(expr->snode->get_node_type_name_hinted()); + } else { + expr->var->accept(this); + } + emit('['); + emit_vector(expr->indices.exprs); + emit(']'); + } + + void visit(TensorElementExpression *expr) override { + expr->var->accept(this); + emit('['); + emit_vector(expr->indices.exprs); + emit("] ("); + emit_vector(expr->shape); + emit(", stride = ", expr->stride); + emit(')'); + } + + void visit(RangeAssumptionExpression *expr) override { + emit("assume_in_range({"); + expr->base->accept(this); + emit(fmt::format("{:+d}", expr->low), " <= ("); + expr->input->accept(this); + emit(") < "); + expr->base->accept(this); + emit(fmt::format("{:+d})", expr->high)); + } + + void visit(LoopUniqueExpression *expr) override { + emit("loop_unique("); + expr->input->accept(this); + if (!expr->covers.empty()) { + emit(", covers=["); + emit_vector(expr->covers); + emit(']'); + } + emit(')'); + } + + void visit(IdExpression *expr) override { + emit(expr->id.name()); + } + + void visit(AtomicOpExpression *expr) override { + const auto op_type = (std::size_t)expr->op_type; + constexpr const char *names_table[] = { + "atomic_add", "atomic_sub", "atomic_min", "atomic_max", + "atomic_bit_and", "atomic_bit_or", "atomic_bit_xor", + }; + if (op_type > std::size(names_table)) { + // min/max not supported in the LLVM backend yet. + TI_NOT_IMPLEMENTED; + } + emit(names_table[op_type], '('); + expr->dest->accept(this); + emit(", "); + expr->val->accept(this); + emit(")"); + } + + void visit(SNodeOpExpression *expr) override { + emit(snode_op_type_name(expr->op_type)); + emit('(', expr->snode->get_node_type_name_hinted(), ", ["); + emit_vector(expr->indices.exprs); + emit("]"); + if (expr->value.expr) { + emit(' '); + expr->value->accept(this); + } + emit(')'); + } + + void visit(ConstExpression *expr) override { + emit(expr->val.stringify()); + } + + void visit(ExternalTensorShapeAlongAxisExpression *expr) override { + emit("external_tensor_shape_along_axis("); + expr->ptr->accept(this); + emit(", ", expr->axis, ')'); + } + + void visit(FuncCallExpression *expr) override { + emit("func_call(\"", expr->func->func_key.get_full_name(), "\", "); + emit_vector(expr->args.exprs); + emit(')'); + } + + void visit(MeshPatchIndexExpression *expr) override { + emit("mesh_patch_idx()"); + } + + void visit(MeshRelationAccessExpression *expr) override { + if (expr->neighbor_idx) { + emit("mesh_relation_access("); + expr->mesh_idx->accept(this); + emit(", ", mesh::element_type_name(expr->to_type), '['); + expr->neighbor_idx->accept(this); + emit("])"); + } else { + emit("mesh_relation_size("); + expr->mesh_idx->accept(this); + emit(", ", mesh::element_type_name(expr->to_type), ')'); + } + } + + void visit(MeshIndexConversionExpression *expr) override { + emit("mesh_index_conversion(", mesh::conv_type_name(expr->conv_type), ", ", + mesh::element_type_name(expr->idx_type), ", "); + expr->idx->accept(this); + emit(")"); + } + + static std::string expr_to_string(Expr &expr) { + std::ostringstream oss; + ExpressionHumanFriendlyPrinter printer(&oss); + expr->accept(&printer); + return oss.str(); + } + + private: + template + void emit(Args &&... args) { + TI_ASSERT(os_); + (*os_ << ... << std::forward(args)); + } + + template + void emit_vector(std::vector &v) { + if (!v.empty()) { + emit_element(v[0]); + const auto size = v.size(); + for (std::size_t i = 1; i < size; ++i) { + emit(", "); + emit_element(v[i]); + } + } + } + + template + void emit_element(D &&e) { + using T = + typename std::remove_cv::type>::type; + if constexpr (std::is_same_v) { + e->accept(this); + } else if constexpr (std::is_same_v) { + emit(e->get_node_type_name_hinted()); + } else { + emit(std::forward(e)); + } + } + + std::ostream *os_{nullptr}; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 0f0f6a125b04c..b2d6af3bf53e0 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -1,5 +1,6 @@ #include "taichi/ir/frontend_ir.h" +#include "taichi/ir/expression_printer.h" #include "taichi/ir/statements.h" #include "taichi/program/program.h" #include "taichi/common/exceptions.h" @@ -8,7 +9,8 @@ TLANG_NAMESPACE_BEGIN #define TI_ASSERT_TYPE_CHECKED(x) \ TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, \ - "[{}] was not type-checked", x.serialize()) + "[{}] was not type-checked", \ + ExpressionHumanFriendlyPrinter::expr_to_string(x)) FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, SNode *snode, @@ -136,19 +138,6 @@ void RandExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } -void UnaryOpExpression::serialize(std::ostream &ss) { - ss << '('; - if (is_cast()) { - ss << (type == UnaryOpType::cast_value ? "" : "reinterpret_"); - ss << unary_op_type_name(type); - ss << '<' << data_type_name(cast_type) << "> "; - } else { - ss << unary_op_type_name(type) << ' '; - } - operand->serialize(ss); - ss << ')'; -} - void UnaryOpExpression::type_check(CompileConfig *) { TI_ASSERT_TYPE_CHECKED(operand); if (!operand->ret_type->is()) @@ -307,21 +296,6 @@ void GlobalPtrExpression::type_check(CompileConfig *) { } } -void GlobalPtrExpression::serialize(std::ostream &ss) { - if (snode) { - ss << snode->get_node_type_name_hinted(); - } else { - var.serialize(ss); - } - ss << '['; - for (int i = 0; i < (int)indices.size(); i++) { - indices.exprs[i]->serialize(ss); - if (i + 1 < (int)indices.size()) - ss << ", "; - } - ss << ']'; -} - void GlobalPtrExpression::flatten(FlattenContext *ctx) { std::vector index_stmts; std::vector offsets; @@ -430,21 +404,6 @@ void LoopUniqueExpression::type_check(CompileConfig *) { ret_type = input->ret_type; } -void LoopUniqueExpression::serialize(std::ostream &ss) { - ss << "loop_unique("; - input.serialize(ss); - for (int i = 0; i < covers.size(); i++) { - if (i == 0) - ss << ", covers=["; - ss << covers[i]->get_node_type_name_hinted(); - if (i == (int)covers.size() - 1) - ss << ']'; - else - ss << ", "; - } - ss << ')'; -} - void LoopUniqueExpression::flatten(FlattenContext *ctx) { flatten_rvalue(input, ctx); ctx->push_back(Stmt::make(input->stmt, covers)); @@ -477,31 +436,6 @@ void AtomicOpExpression::type_check(CompileConfig *) { } } -void AtomicOpExpression::serialize(std::ostream &ss) { - if (op_type == AtomicOpType::add) { - ss << "atomic_add("; - } else if (op_type == AtomicOpType::sub) { - ss << "atomic_sub("; - } else if (op_type == AtomicOpType::min) { - ss << "atomic_min("; - } else if (op_type == AtomicOpType::max) { - ss << "atomic_max("; - } else if (op_type == AtomicOpType::bit_and) { - ss << "atomic_bit_and("; - } else if (op_type == AtomicOpType::bit_or) { - ss << "atomic_bit_or("; - } else if (op_type == AtomicOpType::bit_xor) { - ss << "atomic_bit_xor("; - } else { - // min/max not supported in the LLVM backend yet. - TI_NOT_IMPLEMENTED; - } - dest.serialize(ss); - ss << ", "; - val.serialize(ss); - ss << ")"; -} - void AtomicOpExpression::flatten(FlattenContext *ctx) { // replace atomic sub with negative atomic add if (op_type == AtomicOpType::sub) { @@ -532,19 +466,6 @@ void SNodeOpExpression::type_check(CompileConfig *) { } } -void SNodeOpExpression::serialize(std::ostream &ss) { - ss << snode_op_type_name(op_type); - ss << '('; - ss << snode->get_node_type_name_hinted() << ", ["; - indices.serialize(ss); - ss << "]"; - if (value.expr) { - ss << ' '; - value.serialize(ss); - } - ss << ')'; -} - void SNodeOpExpression::flatten(FlattenContext *ctx) { std::vector indices_stmt; for (int i = 0; i < (int)indices.size(); i++) { @@ -590,7 +511,7 @@ void ConstExpression::flatten(FlattenContext *ctx) { void ExternalTensorShapeAlongAxisExpression::type_check(CompileConfig *) { TI_ASSERT_INFO(ptr.is(), "Invalid ptr [{}] for ExternalTensorShapeAlongAxisExpression", - ptr.serialize()); + ExpressionHumanFriendlyPrinter::expr_to_string(ptr)); ret_type = PrimitiveType::i32; } @@ -623,12 +544,6 @@ void FuncCallExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } -void FuncCallExpression::serialize(std::ostream &ss) { - ss << "func_call(\"" << func->func_key.get_full_name() << "\", "; - args.serialize(ss); - ss << ')'; -} - // Mesh related. void MeshPatchIndexExpression::flatten(FlattenContext *ctx) { @@ -697,7 +612,8 @@ void ASTBuilder::insert_assignment(Expr &lhs, const Expr &rhs) { } else if (lhs.expr->is_lvalue()) { this->insert(std::make_unique(lhs, rhs)); } else { - TI_ERROR("Cannot assign to non-lvalue: {}", lhs.serialize()); + TI_ERROR("Cannot assign to non-lvalue: {}", + ExpressionHumanFriendlyPrinter::expr_to_string(lhs)); } } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 739f7ed773e89..2e6ef5bf1c9b8 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -282,11 +282,9 @@ class ArgLoadExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << fmt::format("arg[{}] (dt={})", arg_id, data_type_name(dt)); - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class RandExpression : public Expression { @@ -298,11 +296,9 @@ class RandExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << fmt::format("rand<{}>()", data_type_name(dt)); - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class UnaryOpExpression : public Expression { @@ -324,9 +320,9 @@ class UnaryOpExpression : public Expression { bool is_cast() const; - void serialize(std::ostream &ss) override; - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class BinaryOpExpression : public Expression { @@ -340,17 +336,9 @@ class BinaryOpExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << '('; - lhs->serialize(ss); - ss << ' '; - ss << binary_op_type_symbol(type); - ss << ' '; - rhs->serialize(ss); - ss << ')'; - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class TernaryOpExpression : public Expression { @@ -370,17 +358,9 @@ class TernaryOpExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << ternary_type_name(type) << '('; - op1->serialize(ss); - ss << ' '; - op2->serialize(ss); - ss << ' '; - op3->serialize(ss); - ss << ')'; - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class InternalFuncCallExpression : public Expression { @@ -400,22 +380,9 @@ class InternalFuncCallExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << "internal call " << func_name << '('; - if (with_runtime_context) { - ss << "runtime, "; - } - std::string args_str; - for (int i = 0; i < args.size(); i++) { - if (i != 0) { - ss << ", "; - } - args[i]->serialize(ss); - } - ss << ')'; - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; // TODO: Make this a non-expr @@ -450,12 +417,9 @@ class ExternalTensorExpression : public Expression { void type_check(CompileConfig *config) override { } - void serialize(std::ostream &ss) override { - ss << fmt::format("{}d_ext_arr (element_dim={}, dt={})", dim, element_dim, - dt->to_string()); - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; // TODO: Make this a non-expr @@ -486,15 +450,9 @@ class GlobalVariableExpression : public Expression { set_attribute("dim", std::to_string(snode->num_active_indices)); } - void serialize(std::ostream &ss) override { - ss << "#" << ident.name(); - if (snode) - ss << fmt::format(" (snode={})", snode->get_node_type_name_hinted()); - else - ss << fmt::format(" (dt={})", dt->to_string()); - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class GlobalPtrExpression : public Expression { @@ -513,13 +471,13 @@ class GlobalPtrExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override; - void flatten(FlattenContext *ctx) override; bool is_lvalue() const override { return true; } + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class TensorElementExpression : public Expression { @@ -543,29 +501,13 @@ class TensorElementExpression : public Expression { bool is_global_tensor() const; - void serialize(std::ostream &ss) override { - var.serialize(ss); - ss << '['; - for (int i = 0; i < (int)indices.size(); i++) { - indices.exprs[i]->serialize(ss); - if (i + 1 < (int)indices.size()) - ss << ", "; - } - ss << "] ("; - for (int i = 0; i < (int)shape.size(); i++) { - ss << std::to_string(shape[i]); - if (i + 1 < (int)shape.size()) - ss << ", "; - } - ss << ", stride = " + std::to_string(stride); - ss << ')'; - } - void flatten(FlattenContext *ctx) override; bool is_lvalue() const override { return true; } + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class RangeAssumptionExpression : public Expression { @@ -582,18 +524,9 @@ class RangeAssumptionExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << "assume_in_range({"; - base.serialize(ss); - ss << fmt::format("{:+d}", low); - ss << " <= ("; - input.serialize(ss); - ss << ") < "; - base.serialize(ss); - ss << fmt::format("{:+d})", high); - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class LoopUniqueExpression : public Expression { @@ -607,9 +540,9 @@ class LoopUniqueExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override; - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class IdExpression : public Expression { @@ -622,10 +555,6 @@ class IdExpression : public Expression { void type_check(CompileConfig *config) override { } - void serialize(std::ostream &ss) override { - ss << id.name(); - } - void flatten(FlattenContext *ctx) override; Stmt *flatten_noload(FlattenContext *ctx) { @@ -635,6 +564,8 @@ class IdExpression : public Expression { bool is_lvalue() const override { return true; } + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; // ti.atomic_*() is an expression with side effect. @@ -649,9 +580,9 @@ class AtomicOpExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override; - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class SNodeOpExpression : public Expression { @@ -674,9 +605,9 @@ class SNodeOpExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override; - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class ConstExpression : public Expression { @@ -694,11 +625,9 @@ class ConstExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << val.stringify(); - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class ExternalTensorShapeAlongAxisExpression : public Expression { @@ -706,12 +635,6 @@ class ExternalTensorShapeAlongAxisExpression : public Expression { Expr ptr; int axis; - void serialize(std::ostream &ss) override { - ss << "external_tensor_shape_along_axis("; - ptr->serialize(ss); - ss << ", " << axis << ')'; - } - ExternalTensorShapeAlongAxisExpression(const Expr &ptr, int axis) : ptr(ptr), axis(axis) { } @@ -719,6 +642,8 @@ class ExternalTensorShapeAlongAxisExpression : public Expression { void type_check(CompileConfig *config) override; void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class FuncCallExpression : public Expression { @@ -728,13 +653,13 @@ class FuncCallExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override; - FuncCallExpression(Function *func, const ExprGroup &args) : func(func), args(args) { } void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; // Mesh related. @@ -746,11 +671,9 @@ class MeshPatchIndexExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << fmt::format("mesh_patch_idx()"); - } - void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class MeshRelationAccessExpression : public Expression { @@ -762,20 +685,6 @@ class MeshRelationAccessExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - if (neighbor_idx) { - ss << "mesh_relation_access("; - mesh_idx->serialize(ss); - ss << ", " << mesh::element_type_name(to_type) << "["; - neighbor_idx->serialize(ss); - ss << "])"; - } else { - ss << "mesh_relation_size("; - mesh_idx->serialize(ss); - ss << ", " << mesh::element_type_name(to_type) << ")"; - } - } - MeshRelationAccessExpression(mesh::Mesh *mesh, const Expr mesh_idx, mesh::MeshElementType to_type) @@ -793,6 +702,8 @@ class MeshRelationAccessExpression : public Expression { } void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class MeshIndexConversionExpression : public Expression { @@ -804,13 +715,6 @@ class MeshIndexConversionExpression : public Expression { void type_check(CompileConfig *config) override; - void serialize(std::ostream &ss) override { - ss << "mesh_index_conversion(" << mesh::conv_type_name(conv_type) << ", " - << mesh::element_type_name(idx_type) << ", "; - idx->serialize(ss); - ss << ")"; - } - MeshIndexConversionExpression(mesh::Mesh *mesh, mesh::MeshElementType idx_type, const Expr idx, @@ -819,6 +723,8 @@ class MeshIndexConversionExpression : public Expression { } void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION }; class ASTBuilder { diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 438522b5ac366..58e81c82ad534 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -185,6 +185,7 @@ class IRVisitor { #include "taichi/inc/statements.inc.h" #undef PER_STATEMENT +#undef DEFINE_VISIT }; struct CompileConfig; diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index c0c0bcd6972ec..0b25f2248ade5 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -547,8 +547,7 @@ void export_lang(py::module &m) { py::return_value_policy::reference); py::class_ expr(m, "Expr"); - expr.def("serialize", [](Expr *expr) { return expr->serialize(); }) - .def("snode", &Expr::snode, py::return_value_policy::reference) + expr.def("snode", &Expr::snode, py::return_value_policy::reference) .def("is_global_var", [](Expr *expr) { return expr->is(); }) .def("is_external_var", @@ -592,8 +591,7 @@ void export_lang(py::module &m) { py::class_(m, "ExprGroup") .def(py::init<>()) .def("size", [](ExprGroup *eg) { return eg->exprs.size(); }) - .def("push_back", &ExprGroup::push_back) - .def("serialize", [](ExprGroup *eg) { eg->serialize(); }); + .def("push_back", &ExprGroup::push_back); py::class_(m, "Stmt"); diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 192e115029047..265068051647b 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -1,5 +1,6 @@ // The IRPrinter prints the IR in a human-readable format +#include "taichi/ir/expression_printer.h" #include "taichi/ir/ir.h" #include "taichi/ir/statements.h" #include "taichi/ir/transforms.h" @@ -45,6 +46,9 @@ std::string to_string(const LaneAttribute &ptr) { } class IRPrinter : public IRVisitor { + private: + ExpressionHumanFriendlyPrinter expr_printer_; + public: int current_indent; @@ -96,7 +100,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendExprStmt *stmt) override { - print("{}", stmt->val.serialize()); + print("{}", (stmt->val)); } void visit(FrontendBreakStmt *stmt) override { @@ -108,7 +112,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendAssignStmt *assign) override { - print("{} = {}", assign->lhs.serialize(), assign->rhs.serialize()); + print("{} = {}", expr_to_string(assign->lhs), expr_to_string(assign->rhs)); } void visit(FrontendAllocaStmt *alloca) override { @@ -117,7 +121,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendAssertStmt *assert) override { - print("{} : assert {}", assert->name(), assert->cond.serialize()); + print("{} : assert {}", assert->name(), expr_to_string(assert->cond)); } void visit(AssertStmt *assert) override { @@ -155,13 +159,13 @@ class IRPrinter : public IRVisitor { void visit(FrontendSNodeOpStmt *stmt) override { std::string extras = "["; for (int i = 0; i < (int)stmt->indices.size(); i++) { - extras += stmt->indices[i].serialize(); + extras += expr_to_string(stmt->indices[i]); if (i + 1 < (int)stmt->indices.size()) extras += ", "; } extras += "]"; if (stmt->val.expr) { - extras += ", " + stmt->val.serialize(); + extras += ", " + expr_to_string(stmt->val); } print("{} : {} {} {}", stmt->name(), snode_op_type_name(stmt->op_type), stmt->snode->get_node_type_name_hinted(), extras); @@ -242,7 +246,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendIfStmt *if_stmt) override { - print("{} : if {} {{", if_stmt->name(), if_stmt->condition.serialize()); + print("{} : if {} {{", if_stmt->name(), expr_to_string(if_stmt->condition)); if (if_stmt->true_statements) if_stmt->true_statements->accept(this); if (if_stmt->false_statements) { @@ -257,7 +261,7 @@ class IRPrinter : public IRVisitor { for (auto const &c : print_stmt->contents) { std::string name; if (std::holds_alternative(c)) - name = std::get(c).serialize(); + name = expr_to_string(std::get(c).expr.get()); else name = c_quoted(std::get(c)); contents.push_back(name); @@ -319,7 +323,7 @@ class IRPrinter : public IRVisitor { } void visit(FrontendWhileStmt *stmt) override { - print("{} : while {} {{", stmt->name(), stmt->cond.serialize()); + print("{} : while {} {{", stmt->name(), expr_to_string(stmt->cond)); stmt->body->accept(this); print("}}"); } @@ -330,7 +334,7 @@ class IRPrinter : public IRVisitor { [](const Identifier &id) -> std::string { return id.name(); }); if (for_stmt->is_ranged()) { print("{} : for {} in range({}, {}) {}{{", for_stmt->name(), vars, - for_stmt->begin.serialize(), for_stmt->end.serialize(), + expr_to_string(for_stmt->begin), expr_to_string(for_stmt->end), block_dim_info(for_stmt->block_dim)); } else if (for_stmt->mesh_for) { print("{} : for {} in mesh {{", for_stmt->name(), vars); @@ -339,7 +343,7 @@ class IRPrinter : public IRVisitor { for_stmt->global_var.is() ? for_stmt->global_var.cast() ->snode->get_node_type_name_hinted() - : for_stmt->global_var.serialize(), + : expr_to_string(for_stmt->global_var), scratch_pad_info(for_stmt->mem_access_opt), block_dim_info(for_stmt->block_dim)); } @@ -419,7 +423,7 @@ class IRPrinter : public IRVisitor { void visit(FrontendReturnStmt *stmt) override { print("{}{} : return [{}]", stmt->type_hint(), stmt->name(), - stmt->values.serialize()); + expr_group_to_string(stmt->values)); } void visit(ReturnStmt *stmt) override { @@ -758,14 +762,33 @@ class IRPrinter : public IRVisitor { } print(" (inputs="); for (auto &s : stmt->args) { - print(s.serialize()); + print(expr_to_string(s)); } print(", outputs="); for (auto &s : stmt->outputs) { - print(s.serialize()); + print(expr_to_string(s)); } print(")"); } + + private: + std::string expr_to_string(Expr &expr) { + return expr_to_string(expr.expr.get()); + } + + std::string expr_to_string(Expression *expr) { + std::ostringstream oss; + expr_printer_.set_ostream(&oss); + expr->accept(&expr_printer_); + return oss.str(); + } + + std::string expr_group_to_string(ExprGroup &expr_group) { + std::ostringstream oss; + expr_printer_.set_ostream(&oss); + expr_printer_.visit(expr_group); + return oss.str(); + } }; } // namespace diff --git a/tests/python/test_binding.py b/tests/python/test_binding.py index 6a7b197d29e4d..276954c7d4b74 100644 --- a/tests/python/test_binding.py +++ b/tests/python/test_binding.py @@ -9,5 +9,4 @@ def test_binding(): two = taichi_lang.make_const_expr_int(ti.i32, 2) expr = taichi_lang.make_binary_op_expr(taichi_lang.BinaryOpType.add, one, two) - print(expr.serialize()) print(taichi_lang.make_global_store_stmt(None, None))