diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index e2a336c6b6252..678076650e51a 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -1,8 +1,6 @@ -#include -#include "taichi/analysis/offline_cache_util.h" -#include "taichi/common/logging.h" +#include "offline_cache_util.h" + #include "taichi/ir/expr.h" -#include "taichi/ir/expression_printer.h" #include "taichi/ir/frontend_ir.h" #include "taichi/ir/ir.h" #include "taichi/ir/mesh.h" @@ -15,7 +13,15 @@ namespace lang { namespace { +enum class ExprOpCode : std::uint8_t { + NIL, +#define PER_EXPRESSION(x) x, +#include "taichi/inc/expressions.inc.h" +#undef PER_EXPRESSION +}; + enum class StmtOpCode : std::uint8_t { + NIL, EnterBlock, ExitBlock, #define PER_STATEMENT(x) x, @@ -35,14 +41,19 @@ enum class ExternalFuncType : std::uint8_t { BC, }; -class ASTSerializer : public IRVisitor { +enum class MeshRelationAccessType { + Access, // mesh_relation_access + Size, // mesh_relation_size +}; + +class ASTSerializer : public IRVisitor, public ExpressionVisitor { + private: + using ExpressionVisitor::visit; + using IRVisitor::visit; + public: - ASTSerializer(Program *prog, - ExpressionPrinter *expr_printer, - std::ostream *os) - : prog_(prog), os_(os), expr_printer_(expr_printer) { + ASTSerializer(Program *prog, std::ostream *os) : prog_(prog), os_(os) { this->allow_undefined_visitor = true; - expr_printer_->set_ostream(os); } void set_ostream(std::ostream *os) { @@ -53,6 +64,178 @@ class ASTSerializer : public IRVisitor { return this->os_; } + void visit(Expression *expr) override { + expr->accept(this); + } + + void visit(Stmt *stmt) override { + stmt->accept(this); + } + + void visit(ExprGroup &expr_group) override { + emit(expr_group.exprs); + } + + void visit(ArgLoadExpression *expr) override { + emit(ExprOpCode::ArgLoadExpression); + emit(expr->dt); + emit(expr->arg_id); + } + + void visit(RandExpression *expr) override { + emit(ExprOpCode::RandExpression); + emit(expr->dt); + } + + void visit(UnaryOpExpression *expr) override { + emit(ExprOpCode::UnaryOpExpression); + emit(expr->type); + if (expr->is_cast()) { + emit(expr->cast_type); + } + emit(expr->operand); + } + + void visit(BinaryOpExpression *expr) override { + emit(ExprOpCode::BinaryOpExpression); + emit(expr->type); + emit(expr->lhs); + emit(expr->rhs); + } + + void visit(TernaryOpExpression *expr) override { + emit(ExprOpCode::TernaryOpExpression); + emit(expr->type); + emit(expr->op1); + emit(expr->op2); + emit(expr->op3); + } + + void visit(InternalFuncCallExpression *expr) override { + emit(ExprOpCode::InternalFuncCallExpression); + emit(expr->with_runtime_context); + emit(expr->func_name); + emit(expr->args); + } + + void visit(ExternalTensorExpression *expr) override { + emit(ExprOpCode::ExternalTensorExpression); + emit(expr->dt); + emit(expr->dim); + emit(expr->arg_id); + emit(expr->element_dim); + emit(expr->element_shape); + } + + void visit(GlobalVariableExpression *expr) override { + emit(ExprOpCode::GlobalVariableExpression); + emit(expr->ident); + emit(expr->dt); + emit(expr->snode); + emit(expr->has_ambient); + emit(expr->ambient_value); + emit(expr->is_primal); + emit(expr->adjoint); + } + + void visit(GlobalPtrExpression *expr) override { + emit(ExprOpCode::GlobalPtrExpression); + emit(expr->var); + emit(expr->indices.exprs); + } + + void visit(TensorElementExpression *expr) override { + emit(ExprOpCode::TensorElementExpression); + emit(expr->var); + emit(expr->indices.exprs); + emit(expr->shape); + emit(expr->stride); + } + + void visit(RangeAssumptionExpression *expr) override { + emit(ExprOpCode::RangeAssumptionExpression); + emit(expr->input); + emit(expr->base); + emit(expr->low); + emit(expr->high); + } + + void visit(LoopUniqueExpression *expr) override { + emit(ExprOpCode::LoopUniqueExpression); + emit(expr->input); + emit(expr->covers); + } + + void visit(IdExpression *expr) override { + emit(ExprOpCode::IdExpression); + emit(expr->id); + } + + void visit(AtomicOpExpression *expr) override { + emit(ExprOpCode::AtomicOpExpression); + emit(expr->op_type); + emit(expr->dest); + emit(expr->val); + } + + void visit(SNodeOpExpression *expr) override { + emit(ExprOpCode::SNodeOpExpression); + emit(expr->op_type); + emit(expr->snode); + std::size_t count = expr->indices.size(); + if (expr->value.expr) + ++count; + emit(count); + for (const auto &i : expr->indices.exprs) { + emit(i); + } + if (expr->value.expr) { + emit(expr->value); + } + } + + void visit(ConstExpression *expr) override { + emit(ExprOpCode::ConstExpression); + emit(expr->val); + } + + void visit(ExternalTensorShapeAlongAxisExpression *expr) override { + emit(ExprOpCode::ExternalTensorShapeAlongAxisExpression); + emit(expr->ptr); + emit(expr->axis); + } + + void visit(FuncCallExpression *expr) override { + emit(ExprOpCode::FuncCallExpression); + emit(expr->func); + emit(expr->args.exprs); + } + + void visit(MeshPatchIndexExpression *expr) override { + emit(ExprOpCode::MeshPatchIndexExpression); + } + + void visit(MeshRelationAccessExpression *expr) override { + emit(ExprOpCode::MeshRelationAccessExpression); + if (expr->neighbor_idx) { + emit(MeshRelationAccessType::Access); + emit(expr->neighbor_idx); + } else { + emit(MeshRelationAccessType::Size); + } + emit(expr->mesh); + emit(expr->to_type); + emit(expr->mesh_idx); + } + + void visit(MeshIndexConversionExpression *expr) override { + emit(ExprOpCode::MeshIndexConversionExpression); + emit(expr->mesh); + emit(expr->idx_type); + emit(expr->idx); + emit(expr->conv_type); + } + void visit(Block *block) override { emit(StmtOpCode::EnterBlock); emit(static_cast(block->statements.size())); @@ -90,6 +273,8 @@ class ASTSerializer : public IRVisitor { void visit(FrontendAssertStmt *stmt) override { emit(StmtOpCode::FrontendAssertStmt); emit(stmt->cond); + emit(stmt->text); + emit(stmt->args); } void visit(FrontendSNodeOpStmt *stmt) override { @@ -205,10 +390,7 @@ class ASTSerializer : public IRVisitor { } static void run(Program *prog, IRNode *ast, std::ostream *os) { - // Temporary: using ExpressionOfflineCacheKeyGenerator, which will be - // refactored - ExpressionOfflineCacheKeyGenerator generator(prog); - ASTSerializer serializer(prog, &generator, os); + ASTSerializer serializer(prog, os); ast->accept(&serializer); serializer.emit_dependencies(); } @@ -219,11 +401,11 @@ class ASTSerializer : public IRVisitor { std::ostringstream temp_oss; auto *curr_os = this->get_ostream(); this->set_ostream(&temp_oss); - expr_printer_->set_ostream(&temp_oss); std::size_t last_size{0}; do { last_size = real_funcs_.size(); - for (auto &[func, visited] : real_funcs_) { + for (auto &[func, v] : real_funcs_) { + auto &[id, visited] = v; if (!visited) { visited = true; func->ir->accept(this); // Maybe add new func @@ -231,9 +413,9 @@ class ASTSerializer : public IRVisitor { } } while (real_funcs_.size() > last_size); this->set_ostream(curr_os); - expr_printer_->set_ostream(curr_os); emit(static_cast(real_funcs_.size())); - emit(&temp_oss); + auto real_funcs_ast_string = temp_oss.str(); + emit_bytes(real_funcs_ast_string.data(), real_funcs_ast_string.size()); // Serialize snode_trees(Temporary: using offline-cache-key of SNode) // Note: The result of serializing snode_tree_roots_ is not parsable now @@ -257,9 +439,19 @@ class ASTSerializer : public IRVisitor { void emit_bytes(const char *bytes, std::size_t len) { TI_ASSERT(os_); + if (!bytes) + return; os_->write(bytes, len); } + template + void emit(const std::vector &v) { + emit(static_cast(v.size())); + for (const auto &e : v) { + emit(e); + } + } + template void emit(const std::unordered_map &map) { emit(static_cast(map.size())); @@ -284,11 +476,6 @@ class ASTSerializer : public IRVisitor { } } - void emit(std::ostream *os) { - TI_ASSERT(os_ && os); - *os_ << os->rdbuf(); - } - void emit(const std::string &str) { std::size_t size = str.size(); std::size_t offset = string_pool_.size(); @@ -297,29 +484,36 @@ class ASTSerializer : public IRVisitor { emit(offset); } - void emit(SNodeOpType type) { - emit_pod(type); - } - - void emit(SNode *snode) { - TI_ASSERT(snode); - TI_ASSERT(prog_); - emit(static_cast(snode->get_snode_tree_id())); - emit(static_cast(snode->id)); - auto *root = prog_->get_snode_root(snode->get_snode_tree_id()); - snode_tree_roots_.insert(root); - } - - void emit(mesh::MeshElementType type) { - emit_pod(type); + void emit(Function *func) { + TI_ASSERT(func); + auto iter = real_funcs_.find(func); + if (iter != real_funcs_.end()) { + emit(iter->second.first); + } else { + auto [iter, ok] = real_funcs_.insert({func, {real_funcs_.size(), false}}); + TI_ASSERT(ok); + emit(iter->second.first); + } } - void emit(mesh::MeshRelationType type) { - emit_pod(type); + void emit(const TypedConstant &val) { + emit(val.dt); + if (!val.dt->is_primitive(PrimitiveTypeID::unknown)) { + emit(val.stringify()); + } } - void emit(mesh::ConvType type) { - emit_pod(type); + void emit(SNode *snode) { + TI_ASSERT(prog_); + if (snode) { + emit(static_cast(snode->get_snode_tree_id())); + emit(static_cast(snode->id)); + auto *root = prog_->get_snode_root(snode->get_snode_tree_id()); + snode_tree_roots_.insert(root); + } else { + emit(std::numeric_limits::max()); + emit(std::numeric_limits::max()); + } } void emit(const mesh::MeshLocalRelation &r) { @@ -330,6 +524,7 @@ class ASTSerializer : public IRVisitor { } void emit(mesh::Mesh *mesh) { + TI_ASSERT(mesh); emit(mesh->num_patches); emit(mesh->num_elements); emit(mesh->patch_max_element_num); @@ -343,43 +538,25 @@ class ASTSerializer : public IRVisitor { emit(ident.id); } - void emit(const std::vector &identifiers) { - emit(static_cast(identifiers.size())); - for (const auto &id : identifiers) { - emit(id); - } - } - - void emit(PrimitiveTypeID type_id) { - emit_pod(type_id); - } - void emit(const DataType &type) { if (auto *p = type->cast()) { emit(p->type); } else { - TI_NOT_IMPLEMENTED; + auto type_str = type->to_string(); + emit(type_str); } } - void emit(StmtOpCode code) { - emit_pod(code); - } - void emit(IRNode *ir) { TI_ASSERT(ir); ir->accept(this); } void emit(const Expr &expr) { - TI_ASSERT(expr_printer_); - expr.expr->accept(expr_printer_); - } - - void emit(const std::vector &exprs) { - emit(static_cast(exprs.size())); - for (const auto &e : exprs) { - emit(e); + if (expr) { + expr.expr->accept(this); + } else { + emit(ExprOpCode::NIL); } } @@ -399,14 +576,6 @@ class ASTSerializer : public IRVisitor { emit_pod(v); } - void emit(ForLoopType type) { - emit_pod(type); - } - - void emit(SNodeAccessFlag flag) { - emit_pod(flag); - } - void emit(const MemoryAccessOptions &mem_access_options) { auto all_options = mem_access_options.get_all(); emit(static_cast(all_options.size())); @@ -419,15 +588,33 @@ class ASTSerializer : public IRVisitor { } } - void emit(ExternalFuncType type) { - emit_pod(type); +#define DEFINE_EMIT_ENUM(EnumType) \ + void emit(EnumType type) { \ + emit_pod(type); \ } + DEFINE_EMIT_ENUM(ExprOpCode); + DEFINE_EMIT_ENUM(StmtOpCode); + DEFINE_EMIT_ENUM(PrimitiveTypeID); + DEFINE_EMIT_ENUM(UnaryOpType); + DEFINE_EMIT_ENUM(BinaryOpType); + DEFINE_EMIT_ENUM(TernaryOpType); + DEFINE_EMIT_ENUM(AtomicOpType); + DEFINE_EMIT_ENUM(SNodeOpType); + DEFINE_EMIT_ENUM(ForLoopType); + DEFINE_EMIT_ENUM(SNodeAccessFlag); + DEFINE_EMIT_ENUM(MeshRelationAccessType); + DEFINE_EMIT_ENUM(ExternalFuncType); + DEFINE_EMIT_ENUM(mesh::MeshElementType); + DEFINE_EMIT_ENUM(mesh::MeshRelationType); + DEFINE_EMIT_ENUM(mesh::ConvType); + +#undef DEFINE_EMIT_ENUM + Program *prog_{nullptr}; std::ostream *os_{nullptr}; - ExpressionPrinter *expr_printer_{nullptr}; std::unordered_set snode_tree_roots_; - std::unordered_map real_funcs_; + std::unordered_map> real_funcs_; std::vector string_pool_; }; diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 694edca7c6401..7431ab7a85bfb 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -2388,8 +2388,8 @@ FunctionType CodeGenLLVM::gen() { bool needs_cache = false; const auto &config = prog->config; std::string kernel_key; - if (config.offline_cache && this->supports_offline_cache() && - !kernel->is_evaluator) { + if (config.offline_cache && !config.async_mode && + this->supports_offline_cache() && !kernel->is_evaluator) { kernel_key = get_hashed_offline_cache_key(&kernel->program->config, kernel); LlvmOfflineCacheFileReader reader(config.offline_cache_file_path); diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 7d1b463896f61..92c882bcd19ff 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -256,75 +256,5 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { } }; -// Temporary reuse ExpressionHumanFriendlyPrinter -class ExpressionOfflineCacheKeyGenerator - : public ExpressionHumanFriendlyPrinter { - public: - explicit ExpressionOfflineCacheKeyGenerator(Program *prog, - std::ostream *os = nullptr) - : ExpressionHumanFriendlyPrinter(os), prog_(prog) { - } - - void visit(GlobalVariableExpression *expr) override { - emit("#", expr->ident.name()); - if (expr->snode) { - emit("(snode=", this->get_hashed_key_of_snode(expr->snode), ')'); - } else { - emit("(dt=", expr->dt->to_string(), ')'); - } - } - - void visit(GlobalPtrExpression *expr) override { - if (expr->snode) { - emit(this->get_hashed_key_of_snode(expr->snode)); - } else { - expr->var->accept(this); - } - emit('['); - emit_vector(expr->indices.exprs); - emit(']'); - } - - void visit(SNodeOpExpression *expr) override { - emit(snode_op_type_name(expr->op_type)); - emit('(', this->get_hashed_key_of_snode(expr->snode), ", ["); - emit_vector(expr->indices.exprs); - emit(']'); - if (expr->value.expr) { - emit(' '); - expr->value->accept(this); - } - emit(')'); - } - - private: - const std::string &cache_snode_tree_key(int snode_tree_id, - std::string &&key) { - if (snode_tree_id >= snode_tree_key_cache_.size()) { - snode_tree_key_cache_.resize(snode_tree_id + 1); - } - return snode_tree_key_cache_[snode_tree_id] = std::move(key); - } - - std::string get_hashed_key_of_snode(SNode *snode) { - TI_ASSERT(snode && prog_); - auto snode_tree_id = snode->get_snode_tree_id(); - std::string res; - if (snode_tree_id < snode_tree_key_cache_.size() && - !snode_tree_key_cache_[snode_tree_id].empty()) { - res = snode_tree_key_cache_[snode_tree_id]; - } else { - auto *snode_tree_root = prog_->get_snode_root(snode_tree_id); - auto snode_tree_key = - get_hashed_offline_cache_key_of_snode(snode_tree_root); - res = cache_snode_tree_key(snode_tree_id, std::move(snode_tree_key)); - } - return res.append(std::to_string(snode->id)); - } - - Program *prog_{nullptr}; - std::vector snode_tree_key_cache_; -}; - } // namespace lang } // namespace taichi