diff --git a/taichi/analysis/check_fields_registered.cpp b/taichi/analysis/check_fields_registered.cpp index 0046627512bb3..6b417b0475eac 100644 --- a/taichi/analysis/check_fields_registered.cpp +++ b/taichi/analysis/check_fields_registered.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/analysis/count_statements.cpp b/taichi/analysis/count_statements.cpp index 138739eb68b7e..43796c3c501f7 100644 --- a/taichi/analysis/count_statements.cpp +++ b/taichi/analysis/count_statements.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/analysis/detect_fors_with_break.cpp b/taichi/analysis/detect_fors_with_break.cpp index 7a98dd2e54e1f..7deefe34aff81 100644 --- a/taichi/analysis/detect_fors_with_break.cpp +++ b/taichi/analysis/detect_fors_with_break.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include "taichi/ir/frontend_ir.h" #include diff --git a/taichi/analysis/gather_deactivations.cpp b/taichi/analysis/gather_deactivations.cpp index bc3710fa3040b..365699dcad859 100644 --- a/taichi/analysis/gather_deactivations.cpp +++ b/taichi/analysis/gather_deactivations.cpp @@ -1,5 +1,6 @@ #include "taichi/ir/ir.h" -#include +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/analysis/gather_statements.cpp b/taichi/analysis/gather_statements.cpp index d32e72d94a729..3264cfea72181 100644 --- a/taichi/analysis/gather_statements.cpp +++ b/taichi/analysis/gather_statements.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/analysis/gather_used_atomics.cpp b/taichi/analysis/gather_used_atomics.cpp index 5c9a9a0922177..91b3809721fc8 100644 --- a/taichi/analysis/gather_used_atomics.cpp +++ b/taichi/analysis/gather_used_atomics.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include TLANG_NAMESPACE_BEGIN diff --git a/taichi/analysis/has_store_or_atomic.cpp b/taichi/analysis/has_store_or_atomic.cpp index 060d04a470085..9d7444d6f3936 100644 --- a/taichi/analysis/has_store_or_atomic.cpp +++ b/taichi/analysis/has_store_or_atomic.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/analysis/last_store_or_atomic.cpp b/taichi/analysis/last_store_or_atomic.cpp index 5a14318dfdb21..dde2fdd912e0b 100644 --- a/taichi/analysis/last_store_or_atomic.cpp +++ b/taichi/analysis/last_store_or_atomic.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 5ce0cbb33773a..ac31e55af6279 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include #include diff --git a/taichi/analysis/value_diff.cpp b/taichi/analysis/value_diff.cpp index 2231cd9dc56d4..0b09d8d30c2e9 100644 --- a/taichi/analysis/value_diff.cpp +++ b/taichi/analysis/value_diff.cpp @@ -1,6 +1,8 @@ // This pass analyzes compile-time known offsets for two values. #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/analysis/verify.cpp b/taichi/analysis/verify.cpp index 0805090378358..066147ca5e024 100644 --- a/taichi/analysis/verify.cpp +++ b/taichi/analysis/verify.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include #include diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index cc33ef61be9ff..829e8a2c84587 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -5,6 +5,7 @@ #include "taichi/backends/metal/constants.h" #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" #include "taichi/util/line_appender.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 0c24622ec5a0d..714af0ff1ad78 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -7,6 +7,7 @@ #include "taichi/backends/opengl/opengl_data_types.h" #include "taichi/backends/opengl/opengl_kernel_util.h" #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" #include "taichi/util/line_appender.h" #include "taichi/util/macros.h" diff --git a/taichi/codegen/codegen.cpp b/taichi/codegen/codegen.cpp index bdfa4d608c947..19e6532be50f6 100644 --- a/taichi/codegen/codegen.cpp +++ b/taichi/codegen/codegen.cpp @@ -8,7 +8,7 @@ #include "taichi/backends/cuda/codegen_cuda.h" #endif #include "taichi/system/timer.h" -#include "taichi/system/timer.h" +#include "taichi/ir/analysis.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h new file mode 100644 index 0000000000000..f94ed9b4ba9ad --- /dev/null +++ b/taichi/ir/analysis.h @@ -0,0 +1,70 @@ +#pragma once + +#include "taichi/ir/ir.h" +#include +#include +#include + +TLANG_NAMESPACE_BEGIN + +class DiffRange { + private: + bool related; + + public: + int coeff; + int low, high; + + DiffRange() : DiffRange(false, 0) { + } + + DiffRange(bool related, int coeff) : DiffRange(related, 0, 0) { + TI_ASSERT(related == false); + } + + DiffRange(bool related, int coeff, int low) + : DiffRange(related, coeff, low, low + 1) { + } + + DiffRange(bool related, int coeff, int low, int high) + : related(related), coeff(coeff), low(low), high(high) { + if (!related) { + this->low = this->high = 0; + } + } + + bool related_() const { + return related; + } + + bool linear_related() const { + return related && coeff == 1; + } + + bool certain() { + TI_ASSERT(related); + return high == low + 1; + } +}; + +// IR Analysis +namespace irpass::analysis { + +void check_fields_registered(IRNode *root); +int count_statements(IRNode *root); +std::unordered_set detect_fors_with_break(IRNode *root); +std::unordered_set detect_loops_with_continue(IRNode *root); +std::unordered_set gather_deactivations(IRNode *root); +std::vector gather_statements(IRNode *root, + const std::function &test); +std::unique_ptr> gather_used_atomics( + IRNode *root); +bool has_store_or_atomic(IRNode *root, const std::vector &vars); +std::pair last_store_or_atomic(IRNode *root, Stmt *var); +bool same_statements(IRNode *root1, IRNode *root2); +DiffRange value_diff(Stmt *stmt, int lane, Stmt *alloca); +void verify(IRNode *root); + +} // namespace irpass::analysis + +TLANG_NAMESPACE_END diff --git a/taichi/ir/expr.cpp b/taichi/ir/expr.cpp index abe4ba9a6f78b..ecd8b39f10d23 100644 --- a/taichi/ir/expr.cpp +++ b/taichi/ir/expr.cpp @@ -156,4 +156,51 @@ void Expr::operator/=(const Expr &o) { (*this) = (*this) / load_if_ptr(o); } +void Cache(int v, const Expr &var) { + dec.scratch_opt.push_back(std::make_pair(v, var.snode())); +} + +void CacheL1(const Expr &var) { + dec.scratch_opt.push_back(std::make_pair(1, var.snode())); +} + +Expr load_if_ptr(const Expr &ptr) { + if (ptr.is()) { + return load(ptr); + } else if (ptr.is()) { + TI_ASSERT(ptr.cast()->snode->num_active_indices == + 0); + return load(ptr[ExprGroup()]); + } else + return ptr; +} + +Expr load(const Expr &ptr) { + TI_ASSERT(ptr.is()); + return Expr::make(ptr); +} + +Expr ptr_if_global(const Expr &var) { + if (var.is()) { + // singleton global variable + TI_ASSERT(var.snode()->num_active_indices == 0); + return var[ExprGroup()]; + } else { + // may be any local or global expr + return var; + } +} + +Expr Var(const Expr &x) { + auto var = Expr(std::make_shared()); + current_ast_builder().insert(std::make_unique( + std::static_pointer_cast(var.expr)->id, DataType::unknown)); + var = x; + return var; +} + +void Print_(const Expr &a, const std::string &str) { + current_ast_builder().insert(std::make_unique(a, str)); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/expr.h b/taichi/ir/expr.h index cfab4599bf4d2..d9e41242d3392 100644 --- a/taichi/ir/expr.h +++ b/taichi/ir/expr.h @@ -123,4 +123,19 @@ Expr bit_cast(const Expr &input) { return taichi::lang::bit_cast(input, get_data_type()); } +Expr load_if_ptr(const Expr &ptr); +Expr load(const Expr &ptr); +Expr ptr_if_global(const Expr &var); + +inline Expr smart_load(const Expr &var) { + return load_if_ptr(ptr_if_global(var)); +} + +// Begin: legacy frontend functions +void Print_(const Expr &a, const std::string &str); +void Cache(int v, const Expr &var); +void CacheL1(const Expr &var); +Expr Var(const Expr &x); +// End: legacy frontend functions + TLANG_NAMESPACE_END diff --git a/taichi/ir/expression.h b/taichi/ir/expression.h index ea8201466d902..b7b1a187c7029 100644 --- a/taichi/ir/expression.h +++ b/taichi/ir/expression.h @@ -1,91 +1,112 @@ -// Arithmatic operations +#pragma once -#if defined(TI_EXPRESSION_IMPLEMENTATION) +#include "taichi/ir/expr.h" -#undef DEFINE_EXPRESSION_OP_BINARY -#undef DEFINE_EXPRESSION_OP_UNARY -#undef DEFINE_EXPRESSION_FUNC +TLANG_NAMESPACE_BEGIN -#define DEFINE_EXPRESSION_OP_UNARY(opname) \ - Expr opname(const Expr &expr) { \ - return Expr::make(UnaryOpType::opname, expr); \ - } \ - Expr expr_##opname(const Expr &expr) { \ - return opname(expr); \ +#include "taichi/ir/expression_ops.h" + +// always a tree - used as rvalues +class Expression { + public: + Stmt *stmt; + std::string tb; + std::map attributes; + + struct FlattenContext { + VecStatement stmts; + Block *current_block = nullptr; + + inline Stmt *push_back(pStmt &&stmt) { + return stmts.push_back(std::move(stmt)); + } + + template + T *push_back(Args &&... args) { + return stmts.push_back(std::forward(args)...); + } + + Stmt *back_stmt() { + return stmts.back().get(); + } + }; + + Expression() { + stmt = nullptr; } -#define DEFINE_EXPRESSION_OP_BINARY(op, opname) \ - Expr operator op(const Expr &lhs, const Expr &rhs) { \ - return Expr::make(BinaryOpType::opname, lhs, rhs); \ - } \ - Expr expr_##opname(const Expr &lhs, const Expr &rhs) { \ - return lhs op rhs; \ + virtual std::string serialize() = 0; + + virtual void flatten(FlattenContext *ctx) { + TI_NOT_IMPLEMENTED; + }; + + virtual bool is_lvalue() const { + return false; } -#define DEFINE_EXPRESSION_FUNC(opname) \ - Expr opname(const Expr &lhs, const Expr &rhs) { \ - return Expr::make(BinaryOpType::opname, lhs, rhs); \ - } \ - Expr expr_##opname(const Expr &lhs, const Expr &rhs) { \ - return opname(lhs, rhs); \ + virtual ~Expression() { } -#else - -#define DEFINE_EXPRESSION_OP_BINARY(op, opname) \ - Expr operator op(const Expr &lhs, const Expr &rhs); \ - Expr expr_##opname(const Expr &lhs, const Expr &rhs); - -#define DEFINE_EXPRESSION_OP_UNARY(opname) \ - Expr opname(const Expr &expr); \ - Expr expr_##opname(const Expr &expr); - -#define DEFINE_EXPRESSION_FUNC(opname) \ - Expr opname(const Expr &lhs, const Expr &rhs); \ - Expr expr_##opname(const Expr &lhs, const Expr &rhs); - -#endif - -DEFINE_EXPRESSION_OP_UNARY(sqrt) -DEFINE_EXPRESSION_OP_UNARY(floor) -DEFINE_EXPRESSION_OP_UNARY(ceil) -DEFINE_EXPRESSION_OP_UNARY(abs) -DEFINE_EXPRESSION_OP_UNARY(sin) -DEFINE_EXPRESSION_OP_UNARY(asin) -DEFINE_EXPRESSION_OP_UNARY(cos) -DEFINE_EXPRESSION_OP_UNARY(acos) -DEFINE_EXPRESSION_OP_UNARY(tan) -DEFINE_EXPRESSION_OP_UNARY(tanh) -DEFINE_EXPRESSION_OP_UNARY(inv) -DEFINE_EXPRESSION_OP_UNARY(rcp) -DEFINE_EXPRESSION_OP_UNARY(rsqrt) -DEFINE_EXPRESSION_OP_UNARY(exp) -DEFINE_EXPRESSION_OP_UNARY(log) - -DEFINE_EXPRESSION_OP_BINARY(+, add) -DEFINE_EXPRESSION_OP_BINARY(-, sub) -DEFINE_EXPRESSION_OP_BINARY(*, mul) -DEFINE_EXPRESSION_OP_BINARY(/, div) -DEFINE_EXPRESSION_OP_BINARY(%, mod) -DEFINE_EXPRESSION_OP_BINARY(&&, bit_and) -DEFINE_EXPRESSION_OP_BINARY(||, bit_or) -// DEFINE_EXPRESSION_OP_BINARY(&, bit_and) -// DEFINE_EXPRESSION_OP_BINARY(|, bit_or) -DEFINE_EXPRESSION_OP_BINARY (^, bit_xor) -DEFINE_EXPRESSION_OP_BINARY(<, cmp_lt) -DEFINE_EXPRESSION_OP_BINARY(<=, cmp_le) -DEFINE_EXPRESSION_OP_BINARY(>, cmp_gt) -DEFINE_EXPRESSION_OP_BINARY(>=, cmp_ge) -DEFINE_EXPRESSION_OP_BINARY(==, cmp_eq) -DEFINE_EXPRESSION_OP_BINARY(!=, cmp_ne) - -DEFINE_EXPRESSION_FUNC(min); -DEFINE_EXPRESSION_FUNC(max); -DEFINE_EXPRESSION_FUNC(atan2); -DEFINE_EXPRESSION_FUNC(pow); -DEFINE_EXPRESSION_FUNC(truediv); -DEFINE_EXPRESSION_FUNC(floordiv); - -#undef DEFINE_EXPRESSION_OP_UNARY -#undef DEFINE_EXPRESSION_OP_BINARY -#undef DEFINE_EXPRESSION_FUNC + void set_attribute(const std::string &key, const std::string &value) { + attributes[key] = value; + } + + std::string get_attribute(const std::string &key) const; +}; + +class ExprGroup { + public: + std::vector exprs; + + ExprGroup() { + } + + ExprGroup(const Expr &a) { + exprs.push_back(a); + } + + ExprGroup(const Expr &a, const Expr &b) { + exprs.push_back(a); + exprs.push_back(b); + } + + ExprGroup(const ExprGroup &a, const Expr &b) { + exprs = a.exprs; + exprs.push_back(b); + } + + ExprGroup(const Expr &a, const ExprGroup &b) { + exprs = b.exprs; + exprs.insert(exprs.begin(), a); + } + + void push_back(const Expr &expr) { + exprs.emplace_back(expr); + } + + std::size_t size() const { + return exprs.size(); + } + + const Expr &operator[](int i) const { + return exprs[i]; + } + + Expr &operator[](int i) { + return exprs[i]; + } + + std::string serialize() const; + ExprGroup loaded() const; +}; + +inline ExprGroup operator,(const Expr &a, const Expr &b) { + return ExprGroup(a, b); +} + +inline ExprGroup operator,(const ExprGroup &a, const Expr &b) { + return ExprGroup(a, b); +} + +TLANG_NAMESPACE_END diff --git a/taichi/ir/expression_ops.h b/taichi/ir/expression_ops.h new file mode 100644 index 0000000000000..ea8201466d902 --- /dev/null +++ b/taichi/ir/expression_ops.h @@ -0,0 +1,91 @@ +// Arithmatic operations + +#if defined(TI_EXPRESSION_IMPLEMENTATION) + +#undef DEFINE_EXPRESSION_OP_BINARY +#undef DEFINE_EXPRESSION_OP_UNARY +#undef DEFINE_EXPRESSION_FUNC + +#define DEFINE_EXPRESSION_OP_UNARY(opname) \ + Expr opname(const Expr &expr) { \ + return Expr::make(UnaryOpType::opname, expr); \ + } \ + Expr expr_##opname(const Expr &expr) { \ + return opname(expr); \ + } + +#define DEFINE_EXPRESSION_OP_BINARY(op, opname) \ + Expr operator op(const Expr &lhs, const Expr &rhs) { \ + return Expr::make(BinaryOpType::opname, lhs, rhs); \ + } \ + Expr expr_##opname(const Expr &lhs, const Expr &rhs) { \ + return lhs op rhs; \ + } + +#define DEFINE_EXPRESSION_FUNC(opname) \ + Expr opname(const Expr &lhs, const Expr &rhs) { \ + return Expr::make(BinaryOpType::opname, lhs, rhs); \ + } \ + Expr expr_##opname(const Expr &lhs, const Expr &rhs) { \ + return opname(lhs, rhs); \ + } + +#else + +#define DEFINE_EXPRESSION_OP_BINARY(op, opname) \ + Expr operator op(const Expr &lhs, const Expr &rhs); \ + Expr expr_##opname(const Expr &lhs, const Expr &rhs); + +#define DEFINE_EXPRESSION_OP_UNARY(opname) \ + Expr opname(const Expr &expr); \ + Expr expr_##opname(const Expr &expr); + +#define DEFINE_EXPRESSION_FUNC(opname) \ + Expr opname(const Expr &lhs, const Expr &rhs); \ + Expr expr_##opname(const Expr &lhs, const Expr &rhs); + +#endif + +DEFINE_EXPRESSION_OP_UNARY(sqrt) +DEFINE_EXPRESSION_OP_UNARY(floor) +DEFINE_EXPRESSION_OP_UNARY(ceil) +DEFINE_EXPRESSION_OP_UNARY(abs) +DEFINE_EXPRESSION_OP_UNARY(sin) +DEFINE_EXPRESSION_OP_UNARY(asin) +DEFINE_EXPRESSION_OP_UNARY(cos) +DEFINE_EXPRESSION_OP_UNARY(acos) +DEFINE_EXPRESSION_OP_UNARY(tan) +DEFINE_EXPRESSION_OP_UNARY(tanh) +DEFINE_EXPRESSION_OP_UNARY(inv) +DEFINE_EXPRESSION_OP_UNARY(rcp) +DEFINE_EXPRESSION_OP_UNARY(rsqrt) +DEFINE_EXPRESSION_OP_UNARY(exp) +DEFINE_EXPRESSION_OP_UNARY(log) + +DEFINE_EXPRESSION_OP_BINARY(+, add) +DEFINE_EXPRESSION_OP_BINARY(-, sub) +DEFINE_EXPRESSION_OP_BINARY(*, mul) +DEFINE_EXPRESSION_OP_BINARY(/, div) +DEFINE_EXPRESSION_OP_BINARY(%, mod) +DEFINE_EXPRESSION_OP_BINARY(&&, bit_and) +DEFINE_EXPRESSION_OP_BINARY(||, bit_or) +// DEFINE_EXPRESSION_OP_BINARY(&, bit_and) +// DEFINE_EXPRESSION_OP_BINARY(|, bit_or) +DEFINE_EXPRESSION_OP_BINARY (^, bit_xor) +DEFINE_EXPRESSION_OP_BINARY(<, cmp_lt) +DEFINE_EXPRESSION_OP_BINARY(<=, cmp_le) +DEFINE_EXPRESSION_OP_BINARY(>, cmp_gt) +DEFINE_EXPRESSION_OP_BINARY(>=, cmp_ge) +DEFINE_EXPRESSION_OP_BINARY(==, cmp_eq) +DEFINE_EXPRESSION_OP_BINARY(!=, cmp_ne) + +DEFINE_EXPRESSION_FUNC(min); +DEFINE_EXPRESSION_FUNC(max); +DEFINE_EXPRESSION_FUNC(atan2); +DEFINE_EXPRESSION_FUNC(pow); +DEFINE_EXPRESSION_FUNC(truediv); +DEFINE_EXPRESSION_FUNC(floordiv); + +#undef DEFINE_EXPRESSION_OP_UNARY +#undef DEFINE_EXPRESSION_OP_BINARY +#undef DEFINE_EXPRESSION_FUNC diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 6fb702bfb1927..5f6788bdad886 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -5,7 +5,7 @@ #include "taichi/lang_util.h" #include "taichi/ir/ir.h" -#include "taichi/ir/expr.h" +#include "taichi/ir/expression.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/ir/ir.cpp b/taichi/ir/ir.cpp index 6024d7248866c..6f3e0b0cb26d1 100644 --- a/taichi/ir/ir.cpp +++ b/taichi/ir/ir.cpp @@ -1,6 +1,8 @@ // Intermediate representations #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" #include #include @@ -12,7 +14,7 @@ TLANG_NAMESPACE_BEGIN #define TI_EXPRESSION_IMPLEMENTATION -#include "expression.h" +#include "expression_ops.h" IRBuilder ¤t_ast_builder() { return context->builder(); @@ -165,33 +167,6 @@ class StatementTypeNameVisitor : public IRVisitor { #undef PER_STATEMENT }; -Expr load_if_ptr(const Expr &ptr) { - if (ptr.is()) { - return load(ptr); - } else if (ptr.is()) { - TI_ASSERT(ptr.cast()->snode->num_active_indices == - 0); - return load(ptr[ExprGroup()]); - } else - return ptr; -} - -Expr load(const Expr &ptr) { - TI_ASSERT(ptr.is()); - return Expr::make(ptr); -} - -Expr ptr_if_global(const Expr &var) { - if (var.is()) { - // singleton global variable - TI_ASSERT(var.snode()->num_active_indices == 0); - return var[ExprGroup()]; - } else { - // may be any local or global expr - return var; - } -} - int StmtFieldSNode::get_snode_id(SNode *snode) { if (snode == nullptr) return -1; @@ -559,18 +534,6 @@ IRNode *FrontendContext::root() { std::unique_ptr context; -Expr Var(const Expr &x) { - auto var = Expr(std::make_shared()); - current_ast_builder().insert(std::make_unique( - std::static_pointer_cast(var.expr)->id, DataType::unknown)); - var = x; - return var; -} - -void Print_(const Expr &a, const std::string &str) { - current_ast_builder().insert(std::make_unique(a, str)); -} - template <> std::string to_string(const LaneAttribute &ptr) { std::string ret = " ["; @@ -909,6 +872,41 @@ std::unique_ptr ConstStmt::copy() { return std::make_unique(val); } +StructForStmt::StructForStmt(std::vector loop_vars, + SNode *snode, + std::unique_ptr &&body, + int vectorize, + int parallelize, + int block_dim) + : loop_vars(loop_vars), + snode(snode), + body(std::move(body)), + vectorize(vectorize), + parallelize(parallelize), + block_dim(block_dim) { + TI_STMT_REG_FIELDS; +} + +RangeForStmt::RangeForStmt(Stmt *loop_var, + Stmt *begin, + Stmt *end, + std::unique_ptr &&body, + int vectorize, + int parallelize, + int block_dim, + bool strictly_serialized) + : loop_var(loop_var), + begin(begin), + end(end), + body(std::move(body)), + vectorize(vectorize), + parallelize(parallelize), + block_dim(block_dim), + strictly_serialized(strictly_serialized) { + reversed = false; + TI_STMT_REG_FIELDS; +} + OffloadedStmt::OffloadedStmt(OffloadedStmt::TaskType task_type) : OffloadedStmt(task_type, nullptr) { } @@ -974,4 +972,8 @@ bool ContinueStmt::as_return() const { return false; } +void Stmt::infer_type() { + irpass::typecheck(this); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir.h b/taichi/ir/ir.h index 85916350ba82b..117bd559448d6 100644 --- a/taichi/ir/ir.h +++ b/taichi/ir/ir.h @@ -9,140 +9,26 @@ #include "taichi/common/bit.h" #include "taichi/lang_util.h" #include "taichi/ir/snode.h" -#include "taichi/ir/expr.h" #include "taichi/program/compile_config.h" #include "taichi/llvm/llvm_fwd.h" #include "taichi/util/short_name.h" TLANG_NAMESPACE_BEGIN -class DiffRange { - private: - bool related; - - public: - int coeff; - int low, high; - - DiffRange() : DiffRange(false, 0) { - } - - DiffRange(bool related, int coeff) : DiffRange(related, 0, 0) { - TI_ASSERT(related == false); - } - - DiffRange(bool related, int coeff, int low) - : DiffRange(related, coeff, low, low + 1) { - } - - DiffRange(bool related, int coeff, int low, int high) - : related(related), coeff(coeff), low(low), high(high) { - if (!related) { - this->low = this->high = 0; - } - } - - bool related_() const { - return related; - } - - bool linear_related() const { - return related && coeff == 1; - } - - bool certain() { - TI_ASSERT(related); - return high == low + 1; - } -}; - class IRBuilder; class IRNode; class Block; class Stmt; using pStmt = std::unique_ptr; -class DiffRange; class SNode; -using ScratchPadOptions = std::vector>; -class Expression; -class Expr; -class ExprGroup; class ScratchPads; +using ScratchPadOptions = std::vector>; #define PER_STATEMENT(x) class x; #include "taichi/inc/statements.inc.h" #undef PER_STATEMENT -// IR passes -namespace irpass { - -struct OffloadedResult { - // Total size in bytes of the global temporary variables - std::size_t total_size; - // Offloaded local variables to its offset in the global tmps memory. - std::unordered_map local_to_global_offset; -}; - -void re_id(IRNode *root); -void flag_access(IRNode *root); -void die(IRNode *root); -void simplify(IRNode *root, Kernel *kernel = nullptr); -void alg_simp(IRNode *root, const CompileConfig &config); -void whole_kernel_cse(IRNode *root); -void variable_optimization(IRNode *root, bool after_lower_access); -void extract_constant(IRNode *root); -void full_simplify(IRNode *root, - const CompileConfig &config, - Kernel *kernel = nullptr); -void print(IRNode *root, std::string *output = nullptr); -void lower(IRNode *root); -void typecheck(IRNode *root, Kernel *kernel = nullptr); -void loop_vectorize(IRNode *root); -void slp_vectorize(IRNode *root); -void vector_split(IRNode *root, int max_width, bool serial_schedule); -void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt); -void check_out_of_bound(IRNode *root); -void lower_access(IRNode *root, bool lower_atomic, Kernel *kernel = nullptr); -void make_adjoint(IRNode *root, bool use_stack = false); -void constant_fold(IRNode *root); -OffloadedResult offload(IRNode *root); -void fix_block_parents(IRNode *root); -void replace_statements_with(IRNode *root, - std::function filter, - std::function()> generator); -void demote_dense_struct_fors(IRNode *root); -void demote_atomics(IRNode *root); -void reverse_segments(IRNode *root); // for autograd -std::unique_ptr initialize_scratch_pad(StructForStmt *root); -void compile_to_offloads(IRNode *ir, - const CompileConfig &config, - bool vectorize, - bool grad, - bool ad_use_stack, - bool verbose, - bool lower_global_access = true); - -// Analysis -namespace analysis { -void check_fields_registered(IRNode *root); -int count_statements(IRNode *root); -std::unordered_set detect_fors_with_break(IRNode *root); -std::unordered_set detect_loops_with_continue(IRNode *root); -std::unordered_set gather_deactivations(IRNode *root); -std::vector gather_statements(IRNode *root, - const std::function &test); -std::unique_ptr> gather_used_atomics( - IRNode *root); -bool has_store_or_atomic(IRNode *root, const std::vector &vars); -std::pair last_store_or_atomic(IRNode *root, Stmt *var); -bool same_statements(IRNode *root1, IRNode *root2); -DiffRange value_diff(Stmt *stmt, int lane, Stmt *alloca); -void verify(IRNode *root); -} // namespace analysis - -} // namespace irpass - IRBuilder ¤t_ast_builder(); bool maybe_same_address(Stmt *var1, Stmt *var2); @@ -249,14 +135,6 @@ class IRBuilder { void stop_gradient(SNode *); }; -Expr load_if_ptr(const Expr &ptr); -Expr load(const Expr &ptr); -Expr ptr_if_global(const Expr &var); - -inline Expr smart_load(const Expr &var) { - return load_if_ptr(ptr_if_global(var)); -} - class Identifier { public: static int id_counter; @@ -712,9 +590,7 @@ class Stmt : public IRNode { return std::make_unique(std::forward(args)...); } - void infer_type() { - irpass::typecheck(this); - } + void infer_type(); void set_tb(const std::string &tb) { this->tb = tb; @@ -725,109 +601,6 @@ class Stmt : public IRNode { virtual ~Stmt() override = default; }; -// always a tree - used as rvalues -class Expression { - public: - Stmt *stmt; - std::string tb; - std::map attributes; - - struct FlattenContext { - VecStatement stmts; - Block *current_block = nullptr; - - inline Stmt *push_back(pStmt &&stmt) { - return stmts.push_back(std::move(stmt)); - } - - template - T *push_back(Args &&... args) { - return stmts.push_back(std::forward(args)...); - } - - Stmt *back_stmt() { - return stmts.back().get(); - } - }; - - Expression() { - stmt = nullptr; - } - - virtual std::string serialize() = 0; - - virtual void flatten(FlattenContext *ctx) { - TI_NOT_IMPLEMENTED; - }; - - virtual bool is_lvalue() const { - return false; - } - - virtual ~Expression() { - } - - void set_attribute(const std::string &key, const std::string &value) { - attributes[key] = value; - } - - std::string get_attribute(const std::string &key) const; -}; - -class ExprGroup { - public: - std::vector exprs; - - ExprGroup() { - } - - ExprGroup(const Expr &a) { - exprs.push_back(a); - } - - ExprGroup(const Expr &a, const Expr &b) { - exprs.push_back(a); - exprs.push_back(b); - } - - ExprGroup(const ExprGroup &a, const Expr &b) { - exprs = a.exprs; - exprs.push_back(b); - } - - ExprGroup(const Expr &a, const ExprGroup &b) { - exprs = b.exprs; - exprs.insert(exprs.begin(), a); - } - - void push_back(const Expr &expr) { - exprs.emplace_back(expr); - } - - std::size_t size() const { - return exprs.size(); - } - - const Expr &operator[](int i) const { - return exprs[i]; - } - - Expr &operator[](int i) { - return exprs[i]; - } - - std::string serialize() const; - ExprGroup loaded() const; -}; - -inline ExprGroup operator,(const Expr &a, const Expr &b) { - return ExprGroup(a, b); -} - -inline ExprGroup operator,(const ExprGroup &a, const Expr &b) { - return ExprGroup(a, b); -} - class AllocaStmt : public Stmt { public: AllocaStmt(DataType type) { @@ -1039,8 +812,6 @@ class GlobalPtrStmt : public Stmt { DEFINE_ACCEPT }; -#include "expression.h" - class Block : public IRNode { public: Block *parent; @@ -1319,18 +1090,7 @@ class RangeForStmt : public Stmt { int vectorize, int parallelize, int block_dim, - bool strictly_serialized) - : loop_var(loop_var), - begin(begin), - end(end), - body(std::move(body)), - vectorize(vectorize), - parallelize(parallelize), - block_dim(block_dim), - strictly_serialized(strictly_serialized) { - reversed = false; - TI_STMT_REG_FIELDS; - } + bool strictly_serialized); bool is_container_statement() const override { return true; @@ -1369,15 +1129,7 @@ class StructForStmt : public Stmt { std::unique_ptr &&body, int vectorize, int parallelize, - int block_dim) - : loop_vars(loop_vars), - snode(snode), - body(std::move(body)), - vectorize(vectorize), - parallelize(parallelize), - block_dim(block_dim) { - TI_STMT_REG_FIELDS; - } + int block_dim); bool is_container_statement() const override { return true; @@ -1455,8 +1207,6 @@ class WhileStmt : public Stmt { DEFINE_ACCEPT }; -void Print_(const Expr &a, const std::string &str); - extern DecoratorRecorder dec; inline void Vectorize(int v) { @@ -1471,14 +1221,6 @@ inline void StrictlySerialize() { dec.strictly_serialized = true; } -inline void Cache(int v, const Expr &var) { - dec.scratch_opt.push_back(std::make_pair(v, var.snode())); -} - -inline void CacheL1(const Expr &var) { - dec.scratch_opt.push_back(std::make_pair(1, var.snode())); -} - inline void BlockDim(int v) { TI_ASSERT(bit::is_power_of_two(v)); dec.block_dim = v; @@ -1488,8 +1230,6 @@ inline void SLP(int v) { current_ast_builder().insert(Stmt::make(v)); } -Expr Var(const Expr &x); - class VectorElement { public: Stmt *stmt; @@ -1532,6 +1272,3 @@ inline void StmtFieldManager::operator()(const char *key, T &&value) { } TLANG_NAMESPACE_END - -#include "taichi/ir/statements.h" -#include "taichi/ir/visitors.h" diff --git a/taichi/ir/state_machine.cpp b/taichi/ir/state_machine.cpp index c50337a6441d5..1181dd873a0f1 100644 --- a/taichi/ir/state_machine.cpp +++ b/taichi/ir/state_machine.cpp @@ -1,4 +1,6 @@ #include "state_machine.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/analysis.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h new file mode 100644 index 0000000000000..ffbf15d115bb1 --- /dev/null +++ b/taichi/ir/transforms.h @@ -0,0 +1,61 @@ +#pragma once + +#include "taichi/ir/ir.h" +#include +#include +#include + +TLANG_NAMESPACE_BEGIN + +// IR passes +namespace irpass { + +struct OffloadedResult { + // Total size in bytes of the global temporary variables + std::size_t total_size; + // Offloaded local variables to its offset in the global tmps memory. + std::unordered_map local_to_global_offset; +}; + +void re_id(IRNode *root); +void flag_access(IRNode *root); +void die(IRNode *root); +void simplify(IRNode *root, Kernel *kernel = nullptr); +void alg_simp(IRNode *root, const CompileConfig &config); +void whole_kernel_cse(IRNode *root); +void variable_optimization(IRNode *root, bool after_lower_access); +void extract_constant(IRNode *root); +void full_simplify(IRNode *root, + const CompileConfig &config, + Kernel *kernel = nullptr); +void print(IRNode *root, std::string *output = nullptr); +void lower(IRNode *root); +void typecheck(IRNode *root, Kernel *kernel = nullptr); +void loop_vectorize(IRNode *root); +void slp_vectorize(IRNode *root); +void vector_split(IRNode *root, int max_width, bool serial_schedule); +void replace_all_usages_with(IRNode *root, Stmt *old_stmt, Stmt *new_stmt); +void check_out_of_bound(IRNode *root); +void lower_access(IRNode *root, bool lower_atomic, Kernel *kernel = nullptr); +void make_adjoint(IRNode *root, bool use_stack = false); +void constant_fold(IRNode *root); +OffloadedResult offload(IRNode *root); +void fix_block_parents(IRNode *root); +void replace_statements_with(IRNode *root, + std::function filter, + std::function()> generator); +void demote_dense_struct_fors(IRNode *root); +void demote_atomics(IRNode *root); +void reverse_segments(IRNode *root); // for autograd +std::unique_ptr initialize_scratch_pad(StructForStmt *root); +void compile_to_offloads(IRNode *ir, + const CompileConfig &config, + bool vectorize, + bool grad, + bool ad_use_stack, + bool verbose, + bool lower_global_access = true); + +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/taichi/ir/visitors.h b/taichi/ir/visitors.h index 50e45972fb131..7627494f6bed6 100644 --- a/taichi/ir/visitors.h +++ b/taichi/ir/visitors.h @@ -1,5 +1,6 @@ #pragma once -#include "statements.h" +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 519f3ecc34a99..1ce322ced4781 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -7,6 +7,8 @@ #include "taichi/backends/cpu/codegen_cpu.h" #include "taichi/common/testing.h" #include "taichi/util/statistics.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/program/async_engine.h b/taichi/program/async_engine.h index 2d78b09b743ac..4349aa35713cf 100644 --- a/taichi/program/async_engine.h +++ b/taichi/program/async_engine.h @@ -5,6 +5,7 @@ #define TI_RUNTIME_HOST #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/runtime/llvm/context.h" #include "taichi/lang_util.h" diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 0ab80c7c9f3dc..ef25bff074980 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -5,25 +5,25 @@ #include "taichi/program/async_engine.h" #include "taichi/codegen/codegen.h" #include "taichi/backends/cuda/cuda_driver.h" +#include "taichi/ir/transforms.h" TLANG_NAMESPACE_BEGIN namespace { - class CurrentKernelGuard { - Kernel *old_kernel; - Program &program; - - public: - CurrentKernelGuard(Program &program_, Kernel *kernel) - : program(program_) { - old_kernel = program.current_kernel; - program.current_kernel = kernel; - } +class CurrentKernelGuard { + Kernel *old_kernel; + Program &program; + + public: + CurrentKernelGuard(Program &program_, Kernel *kernel) : program(program_) { + old_kernel = program.current_kernel; + program.current_kernel = kernel; + } - ~CurrentKernelGuard() { - program.current_kernel = old_kernel; - } - }; + ~CurrentKernelGuard() { + program.current_kernel = old_kernel; + } +}; } // namespace Kernel::Kernel(Program &program, diff --git a/taichi/struct/struct.cpp b/taichi/struct/struct.cpp index 8ad5eae002368..b054685b68074 100644 --- a/taichi/struct/struct.cpp +++ b/taichi/struct/struct.cpp @@ -1,6 +1,7 @@ // Codegen for the hierarchical data structure #include "taichi/ir/ir.h" +#include "taichi/ir/expression.h" #include "taichi/program/program.h" #include "struct.h" diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index bfbdb2f504d69..41a53aab9c52c 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include "taichi/program/program.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/check_out_of_bound.cpp b/taichi/transforms/check_out_of_bound.cpp index b2400f17d494b..c4245a078e088 100644 --- a/taichi/transforms/check_out_of_bound.cpp +++ b/taichi/transforms/check_out_of_bound.cpp @@ -1,5 +1,7 @@ -#include #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" +#include TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index dacbd44b4db15..2f81cf803d641 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -1,4 +1,7 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index ed2842241462b..c9322a4a41f3d 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -1,6 +1,7 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include "taichi/program/program.h" -#include "taichi/ir/snode.h" #include #include #include @@ -14,49 +15,45 @@ class ConstantFold : public BasicStmtVisitor { ConstantFold() : BasicStmtVisitor() { } - struct JITEvaluatorId - { + struct JITEvaluatorId { int op; DataType ret, lhs, rhs; bool is_binary; - explicit operator JITEvaluatorIdType() const - { + explicit operator JITEvaluatorIdType() const { // For a unique hash value, the number of UnaryOpTypes and BinaryOpTypes // should be no more than 256, and the number of DataTypes should be no // more than 128. - return (JITEvaluatorIdType)op | (JITEvaluatorIdType)ret << 8 - | (JITEvaluatorIdType)lhs << 16 | (JITEvaluatorIdType)rhs << 24 - | (JITEvaluatorIdType)is_binary << 31; + return (JITEvaluatorIdType)op | (JITEvaluatorIdType)ret << 8 | + (JITEvaluatorIdType)lhs << 16 | (JITEvaluatorIdType)rhs << 24 | + (JITEvaluatorIdType)is_binary << 31; } - UnaryOpType unary_op() const - { + UnaryOpType unary_op() const { TI_ASSERT(!is_binary); - return (UnaryOpType) op; + return (UnaryOpType)op; } - BinaryOpType binary_op() const - { + BinaryOpType binary_op() const { TI_ASSERT(is_binary); - return (BinaryOpType) op; + return (BinaryOpType)op; } }; - static Kernel *get_jit_evaluator_kernel(JITEvaluatorId const &id) - { + static Kernel *get_jit_evaluator_kernel(JITEvaluatorId const &id) { auto &cache = get_current_program().jit_evaluator_cache; auto hash_id = JITEvaluatorIdType(id); - auto it = cache.find(hash_id); // We need the hash value to be unique here. - if (it != cache.end()) // cached? + auto it = cache.find(hash_id); // We need the hash value to be unique here. + if (it != cache.end()) // cached? return it->second.get(); auto kernel_name = fmt::format("jit_evaluator_{}", cache.size()); - auto func = [&] () { + auto func = [&]() { auto lhstmt = Stmt::make(0, false); auto rhstmt = Stmt::make(1, false); pStmt oper; if (id.is_binary) { - oper = Stmt::make(id.binary_op(), lhstmt.get(), rhstmt.get()); + oper = Stmt::make(id.binary_op(), lhstmt.get(), + rhstmt.get()); } else { oper = Stmt::make(id.unary_op(), lhstmt.get()); if (unary_op_is_cast(id.unary_op())) { @@ -70,7 +67,8 @@ class ConstantFold : public BasicStmtVisitor { current_ast_builder().insert(std::move(oper)); current_ast_builder().insert(std::move(ret)); }; - auto ker = std::make_unique(get_current_program(), func, kernel_name); + auto ker = + std::make_unique(get_current_program(), func, kernel_name); ker->insert_ret(id.ret); ker->insert_arg(id.lhs, false); if (id.is_binary) @@ -82,18 +80,17 @@ class ConstantFold : public BasicStmtVisitor { return ker_ptr; } - static bool is_good_type(DataType dt) - { + static bool is_good_type(DataType dt) { // ConstStmt of `bad` types like `i8` is not supported by LLVM. // Dis: https://github.com/taichi-dev/taichi/pull/839#issuecomment-625902727 switch (dt) { - case DataType::i32: - case DataType::f32: - case DataType::i64: - case DataType::f64: - return true; - default: - return false; + case DataType::i32: + case DataType::f32: + case DataType::i64: + case DataType::f64: + return true; + default: + return false; } } @@ -111,16 +108,17 @@ class ConstantFold : public BasicStmtVisitor { } }; - static bool jit_evaluate_binary_op(TypedConstant &ret, BinaryOpStmt *stmt, - const TypedConstant &lhs, const TypedConstant &rhs) - { + static bool jit_evaluate_binary_op(TypedConstant &ret, + BinaryOpStmt *stmt, + const TypedConstant &lhs, + const TypedConstant &rhs) { if (!is_good_type(ret.dt)) return false; - JITEvaluatorId id{(int)stmt->op_type, ret.dt, lhs.dt, rhs.dt, - true}; + JITEvaluatorId id{(int)stmt->op_type, ret.dt, lhs.dt, rhs.dt, true}; auto *ker = get_jit_evaluator_kernel(id); auto &ctx = get_current_program().get_context(); - ContextArgSaveGuard _(ctx); // save input args, prevent override current kernel + ContextArgSaveGuard _( + ctx); // save input args, prevent override current kernel ctx.set_arg(0, lhs.val_i64); ctx.set_arg(1, rhs.val_i64); (*ker)(); @@ -128,16 +126,17 @@ class ConstantFold : public BasicStmtVisitor { return true; } - static bool jit_evaluate_unary_op(TypedConstant &ret, UnaryOpStmt *stmt, - const TypedConstant &operand) - { + static bool jit_evaluate_unary_op(TypedConstant &ret, + UnaryOpStmt *stmt, + const TypedConstant &operand) { if (!is_good_type(ret.dt)) return false; JITEvaluatorId id{(int)stmt->op_type, ret.dt, operand.dt, stmt->cast_type, - false}; + false}; auto *ker = get_jit_evaluator_kernel(id); auto &ctx = get_current_program().get_context(); - ContextArgSaveGuard _(ctx); // save input args, prevent override current kernel + ContextArgSaveGuard _( + ctx); // save input args, prevent override current kernel ctx.set_arg(0, operand.val_i64); (*ker)(); ret.val_i64 = get_current_program().fetch_result(0); diff --git a/taichi/transforms/demote_atomics.cpp b/taichi/transforms/demote_atomics.cpp index 6ee331f9c0727..c95b9b39bfbc8 100644 --- a/taichi/transforms/demote_atomics.cpp +++ b/taichi/transforms/demote_atomics.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include #include diff --git a/taichi/transforms/demote_dense_struct_fors.cpp b/taichi/transforms/demote_dense_struct_fors.cpp index 295e0038d62e7..721e8f4af0f92 100644 --- a/taichi/transforms/demote_dense_struct_fors.cpp +++ b/taichi/transforms/demote_dense_struct_fors.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/die.cpp b/taichi/transforms/die.cpp index 9519ce4dcc155..6561957777510 100644 --- a/taichi/transforms/die.cpp +++ b/taichi/transforms/die.cpp @@ -1,6 +1,8 @@ // Dead Instruction Elimination #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/extract_constant.cpp b/taichi/transforms/extract_constant.cpp index c6a1e511ddbed..08e9b8ff4b081 100644 --- a/taichi/transforms/extract_constant.cpp +++ b/taichi/transforms/extract_constant.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/fix_block_parents.cpp b/taichi/transforms/fix_block_parents.cpp index e32273cba60da..88d27747d5dba 100644 --- a/taichi/transforms/fix_block_parents.cpp +++ b/taichi/transforms/fix_block_parents.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index 02b8d34800c57..bc84d8beb5714 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/insert_scratch_pad.cpp b/taichi/transforms/insert_scratch_pad.cpp index f6f877ee58d50..64cc3bc4e3c75 100644 --- a/taichi/transforms/insert_scratch_pad.cpp +++ b/taichi/transforms/insert_scratch_pad.cpp @@ -1,4 +1,7 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include "taichi/ir/scratch_pad.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 617ce6ddb7f97..5843b1d8d5397 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -1,9 +1,10 @@ // The IRPrinter prints the IR in a human-readable format -#include - #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include "taichi/ir/frontend_ir.h" +#include TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/loop_vectorize.cpp b/taichi/transforms/loop_vectorize.cpp index aa3acdb8e9cb9..b4ea591662e81 100644 --- a/taichi/transforms/loop_vectorize.cpp +++ b/taichi/transforms/loop_vectorize.cpp @@ -1,5 +1,8 @@ // The loop vectorizer + #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN @@ -154,4 +157,4 @@ void loop_vectorize(IRNode *root) { } // namespace irpass -TLANG_NAMESPACE_END \ No newline at end of file +TLANG_NAMESPACE_END diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index e74f0b6617c42..4cc1b9d990869 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -1,4 +1,7 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include #include diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index 8e89082a66ff4..3bec42df78101 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -1,7 +1,9 @@ -#include - #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include "taichi/ir/frontend_ir.h" +#include TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/make_adjoint.cpp b/taichi/transforms/make_adjoint.cpp index f62555d56ecce..c2c76ed60927a 100644 --- a/taichi/transforms/make_adjoint.cpp +++ b/taichi/transforms/make_adjoint.cpp @@ -1,6 +1,9 @@ -#include #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include "taichi/ir/frontend.h" +#include TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index fbd40bdc8e885..73f7b85aafa4d 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -1,9 +1,11 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include #include #include -#include "taichi/ir/ir.h" - TLANG_NAMESPACE_BEGIN namespace irpass { diff --git a/taichi/transforms/re_id.cpp b/taichi/transforms/re_id.cpp index 31a0e77121e7c..0e2fdf63a318c 100644 --- a/taichi/transforms/re_id.cpp +++ b/taichi/transforms/re_id.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include "taichi/ir/frontend_ir.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/reverse_segments.cpp b/taichi/transforms/reverse_segments.cpp index c379bf92ab022..c5a153215fc24 100644 --- a/taichi/transforms/reverse_segments.cpp +++ b/taichi/transforms/reverse_segments.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include "taichi/ir/frontend_ir.h" #include diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index bd29dc71798d0..ad6a3c14e5f9d 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -1,7 +1,10 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include #include #include -#include "taichi/ir/ir.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/slp_vectorize.cpp b/taichi/transforms/slp_vectorize.cpp index e66a86444332f..5648047afcadc 100644 --- a/taichi/transforms/slp_vectorize.cpp +++ b/taichi/transforms/slp_vectorize.cpp @@ -1,8 +1,10 @@ // Superword level vectorization +#include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include #include -#include "taichi/ir/ir.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/statement_replace.cpp b/taichi/transforms/statement_replace.cpp index 635423276c177..e8a00b7dd6ff4 100644 --- a/taichi/transforms/statement_replace.cpp +++ b/taichi/transforms/statement_replace.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/statement_usage_replace.cpp b/taichi/transforms/statement_usage_replace.cpp index 185d7126a08b2..7338fe1e81ce5 100644 --- a/taichi/transforms/statement_usage_replace.cpp +++ b/taichi/transforms/statement_usage_replace.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 40b3a7996500d..c0f6d4efe5318 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -1,6 +1,9 @@ // Type checking #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" #include "taichi/program/kernel.h" #include "taichi/ir/frontend.h" diff --git a/taichi/transforms/variable_optimization.cpp b/taichi/transforms/variable_optimization.cpp index 85e2cfc1def55..0d754370d0d34 100644 --- a/taichi/transforms/variable_optimization.cpp +++ b/taichi/transforms/variable_optimization.cpp @@ -1,4 +1,6 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" #include "taichi/ir/state_machine.h" #include diff --git a/taichi/transforms/vector_split.cpp b/taichi/transforms/vector_split.cpp index 57596c96ecddd..cd12ab8734e25 100644 --- a/taichi/transforms/vector_split.cpp +++ b/taichi/transforms/vector_split.cpp @@ -1,6 +1,8 @@ // Split vectors wider than machine vector width into multiple vectors #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/taichi/transforms/whole_kernel_cse.cpp b/taichi/transforms/whole_kernel_cse.cpp index 9d74b4ae38ba8..e156eabb7692f 100644 --- a/taichi/transforms/whole_kernel_cse.cpp +++ b/taichi/transforms/whole_kernel_cse.cpp @@ -1,4 +1,7 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" TLANG_NAMESPACE_BEGIN diff --git a/tests/cpp/test_alg_simp.cpp b/tests/cpp/test_alg_simp.cpp index 6a2b400705f05..e738b862ecb1d 100644 --- a/tests/cpp/test_alg_simp.cpp +++ b/tests/cpp/test_alg_simp.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/frontend.h" +#include "taichi/ir/transforms.h" #include "taichi/common/testing.h" TLANG_NAMESPACE_BEGIN diff --git a/tests/cpp/test_exception_handling.cpp b/tests/cpp/test_exception_handling.cpp index 2e877b61c56d4..bc185429628e7 100644 --- a/tests/cpp/test_exception_handling.cpp +++ b/tests/cpp/test_exception_handling.cpp @@ -1,5 +1,6 @@ #include "taichi/common/task.h" #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" TLANG_NAMESPACE_BEGIN diff --git a/tests/cpp/test_same_statements.cpp b/tests/cpp/test_same_statements.cpp index d1615883616f2..f236a504955c8 100644 --- a/tests/cpp/test_same_statements.cpp +++ b/tests/cpp/test_same_statements.cpp @@ -1,5 +1,7 @@ -#include -#include +#include "taichi/ir/frontend.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/analysis.h" +#include "taichi/common/testing.h" TLANG_NAMESPACE_BEGIN diff --git a/tests/cpp/test_simplify.cpp b/tests/cpp/test_simplify.cpp index f7ec144165e6b..ac6a0e74b4891 100644 --- a/tests/cpp/test_simplify.cpp +++ b/tests/cpp/test_simplify.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/frontend.h" +#include "taichi/ir/transforms.h" #include "taichi/common/testing.h" TLANG_NAMESPACE_BEGIN diff --git a/tests/cpp/test_stmt_field_manager.cpp b/tests/cpp/test_stmt_field_manager.cpp index a9f5d07b96832..3db2834d3e84c 100644 --- a/tests/cpp/test_stmt_field_manager.cpp +++ b/tests/cpp/test_stmt_field_manager.cpp @@ -1,4 +1,5 @@ #include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" #include "taichi/common/testing.h" TLANG_NAMESPACE_BEGIN