diff --git a/paddle/cinn/ir/schedule/impl/base.cc b/paddle/cinn/ir/schedule/impl/base.cc index e2caa374b69be..c46d65888d5c9 100644 --- a/paddle/cinn/ir/schedule/impl/base.cc +++ b/paddle/cinn/ir/schedule/impl/base.cc @@ -36,16 +36,36 @@ namespace cinn { namespace ir { void DyScheduleImpl::MergeExprs() { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "MergeExprs"; + std::ostringstream os; auto exprs = this->GetModule().GetExprs(); if (exprs.size() == 1U) return; - CHECK(exprs[0].As()); - CHECK_EQ(exprs[0].As()->stmts.size(), 1U); - CHECK(exprs[0].As()->stmts[0].As()); - CHECK(exprs[0] - .As() - ->stmts[0] - .As() - ->schedule_block.As()); + if (!exprs[0].As()) { + os << "Expr[0] of module_expr should be a Block!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (exprs[0].As()->stmts.size() != 1U) { + os << "Expr[0] of module_expr should have only one stmt!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!exprs[0].As()->stmts[0].As()) { + os << "Expr[0] of module_expr should be Block with only one stmt which is " + "a " + "ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!exprs[0] + .As() + ->stmts[0] + .As() + ->schedule_block.As()) { + os << "Expr[0] of module_expr should be Block with only one stmt which is " + "a " + "ScheduleBlockRealize with a defined ScheduleBlock!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + std::vector merged_block; merged_block.push_back(exprs[0] .As() @@ -83,66 +103,122 @@ void DyScheduleImpl::MergeExprs() { VLOG(3) << "After merging, exprs[0] is : " << exprs[0]; exprs.erase(exprs.begin() + 1, exprs.end()); this->SetExprs(exprs); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } bool DyScheduleImpl::HasBlock(const std::string& block_name) const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "HasBlock"; + std::ostringstream os; auto exprs = module_expr_.GetExprs(); return analyzer::HasBlock(exprs, block_name); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } std::vector DyScheduleImpl::GetLoops(const Expr& block) const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "GetLoops"; + std::ostringstream os; auto exprs = module_expr_.GetExprs(); return analyzer::GetLoops(exprs, block); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } std::vector DyScheduleImpl::GetLoops( const std::string& block_name) const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "GetLoops"; + std::ostringstream os; auto exprs = module_expr_.GetExprs(); return analyzer::GetLoops(exprs, block_name); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } std::vector DyScheduleImpl::GetAllBlocks() const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "GetAllBlocks"; + std::ostringstream os; auto exprs = module_expr_.GetExprs(); return analyzer::GetAllBlocks(exprs); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } std::vector DyScheduleImpl::GetChildBlocks(const Expr& expr) const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "GetChildBlocks"; + std::ostringstream os; return analyzer::GetChildBlocks(expr); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::GetBlock(const std::string& block_name) const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "GetBlock"; + std::ostringstream os; auto exprs = module_expr_.GetExprs(); return analyzer::GetBlock(exprs, block_name); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::GetRootBlock(const Expr& expr) const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "GetRootBlock"; + std::ostringstream os; auto exprs = module_expr_.GetExprs(); return analyzer::GetRootBlock(exprs, expr); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } DeviceAPI DyScheduleImpl::GetDeviceAPI() const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "GetDeviceAPI"; + std::ostringstream os; auto exprs = module_expr_.GetExprs(); return analyzer::GetDeviceAPI(exprs); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::Annotate(const Expr& block, const std::string& key, const attr_t& value) { - CHECK(block.As()); - CHECK(block.As() - ->schedule_block.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Annotate"; + std::ostringstream os; + if (!block.As()) { + os << "Expr param(block) must be a ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!block.As() + ->schedule_block.As()) { + os << "Expr param(block) must be a ScheduleBlockRealize with a " + "defined ScheduleBlock!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto copied_block = ir::ir_utils::IRCopy(block); auto* schedule_block = copied_block.As() ->schedule_block.As(); schedule_block->attrs.emplace(key, value); this->Replace(block, copied_block); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) { // NOLINT - CHECK(block.As()); - CHECK(block.As() - ->schedule_block.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Unannotate"; + std::ostringstream os; + if (!block.As()) { + os << "Expr param(block) must be a ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!block.As() + ->schedule_block.As()) { + os << "Expr param(block) must be a ScheduleBlockRealize with " + "a defined ScheduleBlock!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto* schedule_block = block.As() ->schedule_block.As(); if (schedule_block->attrs.count(ann_key)) { @@ -151,14 +227,31 @@ void DyScheduleImpl::Unannotate(Expr& block, LOG(WARNING) << "Can't find annotation with key: " << ann_key; return; } + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target) { - CHECK(block.As()); - CHECK(block_target.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "CopyTransformAndLoopInfo"; + std::ostringstream os; + + if (!block.As()) { + os << "Expr param(block) must be a " + "ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!block_target.As()) { + os << "Expr param(block_target) must be a " + "ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } auto exprs = this->GetModule().GetExprs(); - CHECK_EQ(exprs.size(), 1U); + if (exprs.size() != 1U) { + os << "Size of exprs of current module must be 1!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto expr = exprs[0]; auto vars = block.As() ->schedule_block.As() @@ -171,8 +264,12 @@ void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, block_target.As()->iter_values; std::vector new_iter_values; for (int i = 0; i < vars.size() && i < vars_target.size(); ++i) { - CHECK(vars[i]->upper_bound.defined() && - vars_target[i]->upper_bound.defined()); + if (!(vars[i]->upper_bound.defined() && + vars_target[i]->upper_bound.defined())) { + os << "Upper bound of iter_vars in both Expr param(block) and Expr " + "param(block_target) must be defined!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } if (vars[i]->upper_bound.is_constant() && vars_target[i]->upper_bound.is_constant() && vars[i]->upper_bound.get_constant() == @@ -185,11 +282,12 @@ void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, } } - if (new_iter_values.empty()) - LOG(FATAL) << "Cannot CopyTransformAndLoopInfo since shape[0] of source " - "and target is not equal! " - << vars[0]->upper_bound << " v.s " - << vars_target[0]->upper_bound; + if (new_iter_values.empty()) { + os << "Cannot CopyTransformAndLoopInfo since shape[0] of source " + "and target is not equal! " + << vars[0]->upper_bound << " v.s " << vars_target[0]->upper_bound; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } int changed_loop_num = new_iter_values.size(); std::set used_target_loop_vars; @@ -200,7 +298,12 @@ void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, return x->as_var(); }); } - CHECK(!used_target_loop_vars.empty()); + if (used_target_loop_vars.empty()) { + os << "Cannot CopyTransformAndLoopInfo since there is no loop var in the " + "new_iter_values!"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + std::vector used_target_loops; auto expr_copy = ir::ir_utils::IRCopy(expr); for (auto& var : used_target_loop_vars) { @@ -211,7 +314,12 @@ void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, Contains(*x, block_target); }, true); - CHECK_EQ(find_loop_var.size(), 1U); + if (find_loop_var.size() != 1U) { + os << "Number of loop with iter_var which is used in " + "ScheduleBlockRealize for indexing in Exprs[0] of module_exprs " + "must be 1!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } used_target_loops.push_back(*find_loop_var.begin()); VLOG(3) << "used_target_loops push_back " << used_target_loops.back(); } @@ -220,7 +328,10 @@ void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size()); }); for (int i = new_iter_values.size(); i < old_iter_values.size(); ++i) { - CHECK(old_iter_values[i].as_var()); + if (!old_iter_values[i].as_var()) { + os << "iter_vars[" << i << "] in Expr param(block) must be vars!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } new_iter_values.push_back(old_iter_values[i]); } Expr new_loop; @@ -230,7 +341,12 @@ void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, new_loop = ir::ir_utils::IRCopy(block); new_loop.As()->iter_values = new_iter_values; } else { - CHECK(old_iter_values[changed_loop_num].as_var()); + if (!old_iter_values[changed_loop_num].as_var()) { + os << "iter_vars[" << changed_loop_num + << "] in Expr param(block) must be vars!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto old_var = old_iter_values[changed_loop_num].as_var_ref(); auto find_partial_loop = ir::ir_utils::CollectIRNodesWithoutTensor( expr, @@ -240,18 +356,31 @@ void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, Contains(*x, block); }, true); - CHECK_EQ(find_partial_loop.size(), 1U); + if (find_partial_loop.size() != 1U) { + os << "Number of loop with iter_var which is " << old_var->name + << " should be 1 in Exprs[0] of module_expr!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } new_loop = ir::ir_utils::IRCopy(*find_partial_loop.begin()); auto find_schedule_block = ir::ir_utils::CollectIRNodesWithoutTensor( new_loop, [&](const Expr* x) { return x->As(); }, true); - CHECK_EQ(find_schedule_block.size(), 1U); + if (find_schedule_block.size() != 1U) { + os << "Number of ScheduleBlockRealize in partial_loop should be 1!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + Expr sch_block = (*find_schedule_block.begin()); sch_block.As()->iter_values = new_iter_values; } VLOG(3) << "new_loop is : " << new_loop; - CHECK(!used_target_loops.empty()); + if (used_target_loops.empty()) { + os << "Cannot CopyTransformAndLoopInfo since there is no loop which use " + "vars in the new_iter_values in Expr[0] of module_expr!"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + Expr res; if (used_target_loops.size() == 1) { auto for_loop = used_target_loops[0].As(); @@ -271,28 +400,44 @@ void DyScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, } VLOG(3) << "res is : " << res; std::vector all_loops = this->GetLoops(block); - CHECK(!all_loops.empty()); + if (all_loops.empty()) { + os << "Cannot CopyTransformAndLoopInfo since there is no loop in Expr " + "param(block)!"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } this->Replace(all_loops[0], res); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::CopyTransformAndLoopInfo( const std::string& block_name, const std::string& block_target_name) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "CopyTransformAndLoopInfo"; + std::ostringstream os; auto block = this->GetBlock(block_name); auto block_target = this->GetBlock(block_target_name); this->CopyTransformAndLoopInfo(block, block_target); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::SampleCategorical( utils::LinearRandomEngine::StateType* rand_seed, const std::vector& candidates, const std::vector& probs) { - // check two sizes - CHECK_EQ(candidates.size(), probs.size()) - << "candidates and probs must have same size."; + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "SampleCategorical"; + std::ostringstream os; + if (candidates.size() != probs.size()) { + os << "vector params(candidates) and vector prama(probs) must " + "have same size in SampleCategorical!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + int seed_idx = utils::SampleDiscreteFromDistribution(probs, rand_seed); auto result = candidates[seed_idx]; Expr result_expr(result); return result_expr; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } std::vector DyScheduleImpl::SamplePerfectTile( @@ -300,13 +445,29 @@ std::vector DyScheduleImpl::SamplePerfectTile( const Expr& loop, int n, int max_innermost_factor) { - CHECK(loop.As()) - << "Expr param of SamplePerfectTile should be a For loop"; - CHECK_GE(n, 2) << "The number of tile factors should be at least 2"; - CHECK_GE(max_innermost_factor, 1) - << "The max innermost factor should be at least 1"; - CHECK(cinn::common::is_zero(loop.As()->min)) - << "The For loop should start from 0"; + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "SamplePerfectTile"; + std::ostringstream os; + if (!loop.As()) { + os << "Expr param(loop) should be a For loop"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (n < 2) { + os << "The number of tile factors should be at least 2"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (max_innermost_factor < 1) { + os << "The max innermost factor should be at least 1"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (!cinn::common::is_zero(loop.As()->min)) { + os << "The For loop should start from 0"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + int loop_extent = GetLoopExtent(loop); std::vector innermost_factors; for (int i = max_innermost_factor; i >= 1; --i) { @@ -314,7 +475,10 @@ std::vector DyScheduleImpl::SamplePerfectTile( innermost_factors.push_back(i); } } - CHECK(!innermost_factors.empty()) << "No innermost factor found"; + if (innermost_factors.empty()) { + os << "No innermost factor found"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } int innermost_factor = innermost_factors[utils::SampleUniformInt( 0, innermost_factors.size(), rand_seed)]; auto result = SampleTile(rand_seed, n - 1, loop_extent / innermost_factor); @@ -324,11 +488,16 @@ std::vector DyScheduleImpl::SamplePerfectTile( } result_expr.push_back(Expr(innermost_factor)); return result_expr; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::AddUnitLoop(const Expr& block) const { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "AddUnitLoop"; + std::ostringstream os; auto exprs = module_expr_.GetExprs(); return analyzer::AddUnitLoop(exprs, block); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } } // namespace ir diff --git a/paddle/cinn/ir/schedule/impl/compute_location.cc b/paddle/cinn/ir/schedule/impl/compute_location.cc index 0a4ade0d24b7f..d27eb23c1ea88 100644 --- a/paddle/cinn/ir/schedule/impl/compute_location.cc +++ b/paddle/cinn/ir/schedule/impl/compute_location.cc @@ -37,8 +37,17 @@ namespace ir { void DyScheduleImpl::ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { - CHECK(block.As()); - CHECK(loop.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "ComputeAt"; + std::ostringstream os; + if (!block.As()) { + os << "Expr prama(block) should be a ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!loop.As()) { + os << "Expr prama(loop) should be a For node!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } Expr root = this->GetRootBlock(block); VLOG(3) << "Begin ComputeAt of loop:\n" << loop << "\nat block:\n" << root; @@ -60,11 +69,23 @@ void DyScheduleImpl::ComputeAt(const Expr& block, this->Replace(reconstructor.loop_, reconstructor.new_loop_); VLOG(3) << "After ComputeAt, ir is:\n" << reconstructor.new_loop_; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { - CHECK(block.As()); - CHECK(loop.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "SimpleComputeAt"; + std::ostringstream os; + if (!block.As()) { + os << "Expr param(block) should be a " + "ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!loop.As()) { + os << "Expr param(loop) should be a For node!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + std::vector block_loops = this->GetLoops(block); Expr root = this->GetRootBlock(block); auto loops = GetLoopsOfExpr(loop, root); @@ -96,7 +117,11 @@ void DyScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { for (int i = 0; i < loops.size(); ++i) { VLOG(3) << i << "-th loop is:\n " << loops[i]; VLOG(3) << i << "-th block_loop:\n" << block_loops[i]; - CHECK_EQ(GetLoopExtent(loops[i]), GetLoopExtent(block_loops[i])); + if (GetLoopExtent(loops[i]) != GetLoopExtent(block_loops[i])) { + os << "Extent of loop in Expr Param(loop) and extent of loop in Expr " + "Param(block) should be equal correspondingly!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } if (block_loops[i].As()->bind_info().valid() && !loops[i].As()->bind_info().valid()) { loops[i].As()->set_bind_info( @@ -175,11 +200,15 @@ void DyScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) { this->Replace(this_loop, new_loop); VLOG(3) << "After SimpleComputeAt, ir is:\n" << new_loop; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::ReverseComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "ReverseComputeAt"; + std::ostringstream os; CHECK(block.As()); CHECK(loop.As()); Expr root = this->GetRootBlock(block); @@ -200,23 +229,40 @@ void DyScheduleImpl::ReverseComputeAt(const Expr& block, this->Replace(reconstructor.source_expr, reconstructor.target_expr); this->Replace(reconstructor.loop_, reconstructor.new_loop_); return; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::ComputeInline(const Expr& schedule_block) { - CHECK(schedule_block.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "ComputeInline"; + std::ostringstream os; + if (!schedule_block.As()) { + os << "Expr param(schedule_block) should be a ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + Expr root = this->GetRootBlock(schedule_block); Expr store = CheckComputeInlineValidationAndGetStore(schedule_block, root); ComputeInliner inliner(store.As()->tensor.as_tensor_ref(), store); - CHECK(inliner.BodyPatternAllowInline()); + + if (!inliner.BodyPatternAllowInline()) { + os << "Current IR can't meets the requirements of ComputeInline!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + // Create a plan that removes the block to be inlined LeafBlockRemovalPlan remove_plan( schedule_block, &inliner.src_stmt, &inliner.tgt_stmt); remove_plan(&root); inliner(&root); return; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::ReverseComputeInline(const Expr& schedule_block) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "ReverseComputeInline"; + std::ostringstream os; Expr root = this->GetRootBlock(schedule_block); auto exprs = CheckReverseComputeInlineValidationAndGetExprs(schedule_block, root); @@ -228,13 +274,17 @@ void DyScheduleImpl::ReverseComputeInline(const Expr& schedule_block) { inlined_store, inlined_load, target_store); - CHECK(inliner.BodyPatternAllowInline()); + if (!inliner.BodyPatternAllowInline()) { + os << "Current IR can't meets the requirements of ReverseComputeInline!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } // Create a plan that removes the block to be inlined LeafBlockRemovalPlan remove_plan( schedule_block, &inliner.src_stmt, &inliner.tgt_stmt); remove_plan(&root); inliner(&root); inliner(&root); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } } // namespace ir diff --git a/paddle/cinn/ir/schedule/impl/for_type.cc b/paddle/cinn/ir/schedule/impl/for_type.cc index 2060ef580a33c..6b045fcc2b342 100644 --- a/paddle/cinn/ir/schedule/impl/for_type.cc +++ b/paddle/cinn/ir/schedule/impl/for_type.cc @@ -38,12 +38,21 @@ namespace ir { void DyScheduleImpl::MutateForType(const Expr& loop, ForType for_type, int factor) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "MutateForType"; + std::ostringstream os; auto* for_node = loop.As(); - CHECK(for_node) << "loop param must be For node! Please check."; - CHECK(for_node->is_serial()) - << "loop is not serial, current forloop type is " - << static_cast(for_node->for_type()) << ", and it cannot become " - << static_cast(for_type); + if (!for_node) { + os << "Loop param must be For node! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (!for_node->is_serial()) { + os << "Loop is not serial, current for loop type is " + << static_cast(for_node->for_type()) << ", and it can't become " + << static_cast(for_type) << "!\n"; + } + auto loop_copy = ir::ir_utils::IRCopy(loop); auto* new_for_node = loop_copy.As(); CHECK(new_for_node); @@ -56,17 +65,22 @@ void DyScheduleImpl::MutateForType(const Expr& loop, new_for_node->set_bind_info(bind_info); } this->Replace(loop, loop_copy); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::Parallel(const Expr& loop) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Parallel"; + std::ostringstream os; MutateForType(loop, ForType::Parallel); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::Vectorize(const Expr& loop, int factor) { CINN_IR_SCHEDULE_BEGIN(); std::string primitive = "Vectorize"; std::ostringstream os; - CHECK_GT(factor, 0) << "vectorize factor should be more than 0"; + if (factor <= 0) { os << "vectorize factor should be more than 0\n"; throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); diff --git a/paddle/cinn/ir/schedule/impl/loop_transformation.cc b/paddle/cinn/ir/schedule/impl/loop_transformation.cc index daa453d85a744..77781e193d22b 100644 --- a/paddle/cinn/ir/schedule/impl/loop_transformation.cc +++ b/paddle/cinn/ir/schedule/impl/loop_transformation.cc @@ -39,14 +39,25 @@ namespace ir { std::vector DyScheduleImpl::Split(const Expr& loop, const std::vector& factors) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Split"; + std::ostringstream os; + + if (!loop.As()) { + os << "Expr param(loop) must be For node! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto* for_node = loop.As(); + if (!cinn::common::is_zero(for_node->min)) { + os << "The For node must start with 0! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (factors.empty()) { + os << "The factors param of Split should not be empty! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (loop.As()->extent.is_constant()) { - CHECK(loop.As()) - << "Expr param of Split must be For node! Please check."; - auto* for_node = loop.As(); - CHECK(cinn::common::is_zero(for_node->min)) - << "The For node must start with 0! Please check."; - CHECK(for_node->extent.is_constant()) - << "The For node's extent must be constant! Please check."; int tot_extent = for_node->extent.get_constant(); VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " @@ -55,10 +66,8 @@ std::vector DyScheduleImpl::Split(const Expr& loop, << loop; std::vector processed_factors; - CINN_IR_SCHEDULE_BEGIN(); processed_factors = ValidateFactors(factors, tot_extent, this->module_expr_); - CINN_IR_SCHEDULE_END(this->err_msg_level_); int prod_size = std::accumulate(processed_factors.begin(), processed_factors.end(), 1, @@ -95,13 +104,6 @@ std::vector DyScheduleImpl::Split(const Expr& loop, VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0); return splited_loops; } - CHECK(loop.As()) - << "Expr param of Split must be For node! Please check."; - auto* for_node = loop.As(); - CHECK(common::is_zero(for_node->min)) - << "The For node must start with 0! Please check."; - CHECK(!factors.empty()) - << "The factors param of Split should not be empty! Please check."; Expr tot_extent = for_node->extent; @@ -125,9 +127,12 @@ std::vector DyScheduleImpl::Split(const Expr& loop, if (factor < 1 && factor != -1) is_positive = false; if (factor == -1) ++num_minus1; }); - CHECK((num_minus1 <= 1) && is_positive) - << "The params in factors of Split on dynamic shape should contains at " - "most one '-1' and the rest of them should be positive!\n"; + + if (num_minus1 > 1 || (!is_positive)) { + os << "The params in factors of Split on dynamic shape should contains at " + "most one '-1' and the rest of them should be positive!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } std::vector new_loop_vars; Expr substitute_value(0); @@ -158,21 +163,37 @@ std::vector DyScheduleImpl::Split(const Expr& loop, this->Replace(loop, new_node); VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0); return splited_loops; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } // TODO(@LiuYang): now -1 can't exsit in factors, std::vector DyScheduleImpl::Split(const Expr& loop, const std::vector& factors) { - CHECK(loop.As()) - << "Expr param of Split must be For node! Please check."; + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Split"; + std::ostringstream os; + if (!loop.As()) { + os << "Expr param(loop) must be For node! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto* for_node = loop.As(); - CHECK(common::is_zero(for_node->min)) - << "The For node must start with 0! Please check."; - CHECK(!factors.empty()) - << "The factors param of Split should not be empty! Please check."; - CHECK(!loop.As()->extent.is_constant()) - << "Can't Split a loop with constant extent but with variable in " - "factors!"; + + if (!common::is_zero(for_node->min)) { + os << "The For node must start with 0! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (factors.empty()) { + os << "The factors param of Split should not be empty! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (loop.As()->extent.is_constant()) { + os << "Can't Split a loop with constant extent but with variable in " + "factors! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + Expr tot_extent = for_node->extent; VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, " @@ -185,9 +206,11 @@ std::vector DyScheduleImpl::Split(const Expr& loop, for (auto factor : factors) prod_size = prod_size * Expr(factor); common::cas_intervals_t var_intervals = {}; cinn::common::SymbolicExprAnalyzer analyzer(var_intervals); - CHECK(analyzer.ProveEQ(tot_extent, prod_size).value_or(false)) - << "Product of factors can't be proved to be equal to the extent of " - "current for loop!"; + if (!analyzer.ProveEQ(tot_extent, prod_size).value_or(false)) { + os << "Product of factors can't be proved to be equal to the extent of " + "current for loop! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } std::vector new_loop_vars; Expr substitute_value(0); @@ -216,26 +239,45 @@ std::vector DyScheduleImpl::Split(const Expr& loop, this->Replace(loop, new_node); VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0); return splited_loops; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::Fuse(const std::vector& loops) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Fuse"; + std::ostringstream os; + VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n"); std::vector for_nodes; std::vector loop_vars; - CHECK(!loops.empty()) - << "The loops param of Fuse should not be empty! Please check."; + if (loops.empty()) { + os << "The loops param of Fuse should not be empty! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } for (const Expr& it_loop : loops) { - CHECK(it_loop.As()) - << "Expr param of Fuse must be For node! Please check."; + if (!it_loop.As()) { + os << "Loop in vector param(loops) of Fuse must be For node! " + "Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + if (!for_nodes.empty()) { - CHECK(for_nodes.back()->body.As()) - << "The body of for node is not Block!"; - CHECK_EQ(for_nodes.back()->body.As()->stmts.size(), 1U) - << "The Block'size of for node is not 1!"; - CHECK_EQ(for_nodes.back()->body.As()->stmts[0], it_loop) - << "The For nodes in loops param of Fuse must be adjacent! Please " - "check."; + if (!for_nodes.back()->body.As()) { + os << "The body of for node is not Block! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (for_nodes.back()->body.As()->stmts.size() != 1) { + os << "The Block's size of for node is not 1! Please check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (for_nodes.back()->body.As()->stmts[0] != it_loop) { + os << "The For nodes in loops param of Fuse must be adjacent! Please " + "check!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } } for_nodes.push_back(it_loop.As()); loop_vars.push_back(it_loop.As()->loop_var); @@ -276,47 +318,77 @@ Expr DyScheduleImpl::Fuse(const std::vector& loops) { VLOG(3) << "After fuse, ir is:\n" << new_stmt; return new_stmt; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::Fuse(const std::string& block_name, const std::vector& loops_index) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Fuse"; + std::ostringstream os; std::vector all_loops = this->GetLoops(block_name); std::vector loops_expr; loops_expr.reserve(loops_index.size()); for (int i = 0; i < loops_index.size(); ++i) { - if (i > 0) - CHECK_EQ(loops_index[i - 1] + 1, loops_index[i]) - << "Loops index in Fuse shoule be continuous!"; + if (i > 0) { + if (loops_index[i - 1] + 1 != loops_index[i]) { + os << "Loops index in Fuse shoule be continuous!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + } } for (int i : loops_index) { - CHECK_LT(i, (int)all_loops.size()) - << "The loop index in Fuse should be less than total loop's number."; - CHECK_GE(i, 0) << "The loop index in Fuse should be >= 0."; + if (i >= static_cast(all_loops.size())) { + os << "The loop index in Fuse should be less than total loop's number!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (i < 0) { + os << "The loop index in Fuse should be >= 0!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } loops_expr.emplace_back(all_loops[i]); } return this->Fuse(loops_expr); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::Fuse(const Expr& block, const std::vector& loops_index) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Fuse"; + std::ostringstream os; std::vector all_loops = this->GetLoops(block); std::vector loops_expr; loops_expr.reserve(loops_index.size()); for (int i = 0; i < loops_index.size(); ++i) { - if (i > 0) - CHECK_EQ(loops_index[i - 1] + 1, loops_index[i]) - << "Loops index in Fuse shoule be continuous!"; + if (i > 0) { + if (loops_index[i - 1] + 1 != loops_index[i]) { + os << "Loops index in Fuse shoule be continuous!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + } } for (int i : loops_index) { - CHECK_LT(i, (int)all_loops.size()) - << "The loop index in Fuse should be less than total loop's number."; - CHECK_GE(i, 0) << "The loop index in Fuse should be >= 0."; + if (i >= static_cast(all_loops.size())) { + os << "The loop index in Fuse should be less than total loop's number!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (i <= 0) { + os << "The loop index in Fuse should be >= 0!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } loops_expr.emplace_back(all_loops[i]); } return this->Fuse(loops_expr); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::Reorder(const std::vector& loops) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Reorder"; + std::ostringstream os; if (loops.size() <= 1) { return Expr{nullptr}; } @@ -333,34 +405,60 @@ Expr DyScheduleImpl::Reorder(const std::vector& loops) { VLOG(4) << "After Reorder, ir is:\n" << new_loop; return new_loop; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::Reorder(const std::string& block_name, const std::vector& loops_index) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Reorder"; + std::ostringstream os; + std::vector all_loops = this->GetLoops(block_name); std::vector loops_expr; loops_expr.reserve(loops_index.size()); for (int i : loops_index) { - CHECK_LT(i, (int)all_loops.size()) - << "The loop index in Reorder should be less than total loop's number."; - CHECK_GE(i, 0) << "The loop index in Reorder should be >= 0."; + if (i >= static_cast(all_loops.size())) { + os << "The loop index in Reorder should be less than total loop's " + "number!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (i < 0) { + os << "The loop index in Reorder should be >= 0!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } loops_expr.emplace_back(all_loops[i]); } return this->Reorder(loops_expr); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::Reorder(const Expr& block, const std::vector& loops_index) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Reorder"; + std::ostringstream os; + std::vector all_loops = this->GetLoops(block); std::vector loops_expr; loops_expr.reserve(loops_index.size()); for (int i : loops_index) { - CHECK_LT(i, (int)all_loops.size()) - << "The loop index in Reorder should be less than total loop's number."; - CHECK_GE(i, 0) << "The loop index in Reorder should be >= 0."; + if (i >= static_cast(all_loops.size())) { + os << "The loop index in Reorder should be less than total loop's " + "number!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + + if (i < 0) { + os << "The loop index in Reorder should be >= 0!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + loops_expr.emplace_back(all_loops[i]); } return this->Reorder(loops_expr); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::FlattenLoops(const std::vector& loops, diff --git a/paddle/cinn/ir/schedule/impl/reduction.cc b/paddle/cinn/ir/schedule/impl/reduction.cc index 0d6e428c9bd37..12d5df3f99932 100644 --- a/paddle/cinn/ir/schedule/impl/reduction.cc +++ b/paddle/cinn/ir/schedule/impl/reduction.cc @@ -36,6 +36,10 @@ namespace cinn { namespace ir { Expr DyScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) { + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "Rfactor"; + std::ostringstream os; + CHECKRfactorValidation(rf_loop, rf_axis); // get root ScheduleBlockRealize Expr root = GetRootBlock(rf_loop); @@ -43,16 +47,18 @@ Expr DyScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) { RfCreater rf_create(root, rf_loop, rf_axis); // return new created rfactor tensor return rf_create.CreateRfAllStmts(); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { + CINN_IR_SCHEDULE_BEGIN() std::string primitive = "FactorizeReduction"; + std::ostringstream os; // Get child block of the rf_loop and check. std::vector blocks = GetChildBlocks(rf_loop); if (blocks.size() != 1) { - std::ostringstream os; os << "The rf_loop is required to have only one child block, but got " - << blocks.size() << std::endl; + << blocks.size() << "!\n"; throw IRScheduleErrorHandler(primitive, os.str(), this->module_expr_); } Expr original_block = blocks.at(0); @@ -62,7 +68,10 @@ Expr DyScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { // Collect the loops of the block. // Construct a map from loop var names to corresponding loops. std::vector original_loops = this->GetLoops(original_block); - CHECK_GT(original_loops.size(), 0); + if (original_loops.size() <= 0) { + os << "The size of original_loops should be great than 0!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), this->module_expr_); + } VLOG(3) << "before FactorizeReduction, original computational body of the " "reduction is:\n" << original_loops[0]; @@ -116,6 +125,7 @@ Expr DyScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { "reduction is:\n" << new_computational_body; return rf_tensor; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } } // namespace ir diff --git a/paddle/cinn/ir/schedule/impl/storage.cc b/paddle/cinn/ir/schedule/impl/storage.cc index 5c8b2bb8305c3..a35683ba68138 100644 --- a/paddle/cinn/ir/schedule/impl/storage.cc +++ b/paddle/cinn/ir/schedule/impl/storage.cc @@ -38,11 +38,24 @@ namespace ir { Expr DyScheduleImpl::CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type) { - CHECK(block.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "CacheRead"; + std::ostringstream os; + + if (!block.As()) { + os << "Expr param(block) is not a ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto root = GetRootBlock(block); ChangeBodyToBlock::Change(&root); Expr read_expr = GetNthAccessExpr(block, read_buffer_index, false); - CHECK(read_expr.As()); + + if (!read_expr.As()) { + os << "The read_expr is not a Load!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto tensor_indices = read_expr.As()->indices; CacheBlockInfo info; info.read_tensor = read_expr.As()->tensor.as_tensor_ref(); @@ -61,16 +74,30 @@ Expr DyScheduleImpl::CacheRead(const Expr& block, ->schedule_block.As() ->body); return new_block; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } Expr DyScheduleImpl::CacheWrite(const Expr& block, int write_buffer_index, const std::string& memory_type) { - CHECK(block.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "CacheWrite"; + std::ostringstream os; + + if (!block.As()) { + os << "Expr param(block) is not a ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto root = GetRootBlock(block); ChangeBodyToBlock::Change(&root); Expr write_expr = GetNthAccessExpr(block, write_buffer_index, true); - CHECK(write_expr.As()); + + if (!write_expr.As()) { + os << "The write_expr is not a Store!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + Tensor write_tensor = write_expr.As()->tensor.as_tensor_ref(); auto tensor_indices = write_expr.As()->indices; CacheBlockInfo info; @@ -99,7 +126,10 @@ Expr DyScheduleImpl::CacheWrite(const Expr& block, }, true); - CHECK(info.write_tensor->buffer.defined()); + if (!info.write_tensor->buffer.defined()) { + os << "The buffer of current write_tensor is not defined!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } // Replace buffer auto all_tensors = @@ -115,28 +145,52 @@ Expr DyScheduleImpl::CacheWrite(const Expr& block, } } - CHECK_EQ(find_cache_block.size(), 1U); + if (find_cache_block.size() != 1U) { + os << "Size of find_cache_block is not 1!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } return *find_cache_block.begin(); + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) { - CHECK(ir_node.As() || ir_node.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "SyncThreads"; + std::ostringstream os; + + if (!(ir_node.As() || ir_node.As())) { + os << "Expr param(ir_node) should be a ScheduleBlockRealize or For!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto root = GetRootBlock(ir_node); ChangeBodyToBlock::Change(&root); Expr sync_threads = runtime::IntrinsicCall(Void(), "__syncthreads", {}); InsertExpr::Insert(ir_node, sync_threads, after_node, &root); return; + CINN_IR_SCHEDULE_END(this->err_msg_level_); } void DyScheduleImpl::SetBuffer(Expr& block, // NOLINT const std::string& memory_type, bool fixed) { - CHECK(block.As()); + CINN_IR_SCHEDULE_BEGIN(); + std::string primitive = "SetBuffer"; + std::ostringstream os; + if (!block.As()) { + os << "Expr param(block) is not a ScheduleBlockRealize!\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto find_tensor = ir::ir_utils::CollectIRNodesWithoutTensor( block, [&](const Expr* x) { return x->As(); }, true); - CHECK_EQ(find_tensor.size(), 1U) - << "One block should only have one Store node!(except for root block)"; + + if (find_tensor.size() != 1U) { + os << "One block should only have one Store node!(except for root block)\n"; + throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); + } + auto& tensor = (*find_tensor.begin()).As()->tensor; tensor.as_tensor_ref()->WithBuffer( memory_type, "_" + tensor.as_tensor_ref()->name + "_temp_buffer"); @@ -151,7 +205,6 @@ void DyScheduleImpl::SetBuffer(Expr& block, // NOLINT tensor.as_tensor_ref()->name + "__reduce_init"); }); for (auto& t : find_tensor) { - CHECK(t.as_tensor()); t.as_tensor_ref()->Bind(tensor.as_tensor_ref()->buffer); } } @@ -164,6 +217,7 @@ void DyScheduleImpl::SetBuffer(Expr& block, // NOLINT auto root = GetRootBlock(block); mutator(&root); } + CINN_IR_SCHEDULE_END(this->err_msg_level_); } } // namespace ir } // namespace cinn