diff --git a/cinn/backends/ir_schedule_test.cc b/cinn/backends/ir_schedule_test.cc index 487f8c0a05..a3f2ea8c3f 100644 --- a/cinn/backends/ir_schedule_test.cc +++ b/cinn/backends/ir_schedule_test.cc @@ -2731,6 +2731,50 @@ TEST(IrSchedule, Annotate) { ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), expected_expr); } +TEST(IrSchedule, Unannotate) { + Context::Global().ResetNameId(); + Expr M(32); + Expr N(32); + Placeholder A("A", {M, N}); + auto B = Compute( + {M, N}, [&](Var i, Var j) { return A(i, j); }, "B"); + + auto funcs = cinn::lang::LowerVec( + "test_split_and_fuse1", CreateStages({A, B}), {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true); + ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body})); + auto fused = ir_sch.Fuse("B", {0, 1}); + auto block_b = ir_sch.GetBlock("B"); + ir_sch.Annotate(block_b, "k1", int(64)); + block_b = ir_sch.GetBlock("B"); + ir_sch.Annotate(block_b, "k2", bool(true)); + block_b = ir_sch.GetBlock("B"); + ir_sch.Annotate(block_b, "k3", float(2.0)); + block_b = ir_sch.GetBlock("B"); + ir_sch.Annotate(block_b, "k4", std::string("v4")); + block_b = ir_sch.GetBlock("B"); + ir_sch.Unannotate(block_b, "k1"); + block_b = ir_sch.GetBlock("B"); + ir_sch.Unannotate(block_b, "k2"); + block_b = ir_sch.GetBlock("B"); + ir_sch.Unannotate(block_b, "k3"); + block_b = ir_sch.GetBlock("B"); + ir_sch.Unannotate(block_b, "k4"); + std::string expected_expr = R"ROC({ + ScheduleBlock(root) + { + serial for (i_j_fused, 0, 1024) + { + ScheduleBlock(B) + { + i0, i1 = axis.bind((i_j_fused / 32), (i_j_fused % 32)) + B[i0, i1] = A[i0, i1] + } + } + } +})ROC"; + ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), expected_expr); +} + TEST(IrSchedule, ComplexIndices) { Target target = common::DefaultHostTarget(); ir::Expr M(32); diff --git a/cinn/ir/ir_schedule.cc b/cinn/ir/ir_schedule.cc index fff66d7fec..f83cba4ce3 100644 --- a/cinn/ir/ir_schedule.cc +++ b/cinn/ir/ir_schedule.cc @@ -94,6 +94,7 @@ class ScheduleImpl { Expr Rfactor(const Expr& rf_loop, int rf_axis); Expr AddUnitLoop(const Expr& block) const; void Annotate(const Expr& block, const std::string& key, const attr_t& value); + void Unannotate(Expr& block, const std::string& key); void FlattenLoops(const std::vector& loops, const bool force_flat = false); void CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target); void CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name); @@ -1509,6 +1510,18 @@ void ScheduleImpl::Annotate(const Expr& block, const std::string& key, const att this->Replace(block, copied_block); } +void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) { + CHECK(block.As()); + CHECK(block.As()->schedule_block.As()); + auto* schedule_block = block.As()->schedule_block.As(); + if (schedule_block->attrs.count(ann_key)) { + schedule_block->attrs.erase(ann_key); + } else { + LOG(WARNING) << "Can't find annotation with key: " << ann_key; + return; + } +} + void ScheduleImpl::FlattenLoops(const std::vector& loops, const bool flat_tensor) { CHECK_GT(loops.size(), 0) << "Loops can't be empty!"; // compute loop @@ -2008,6 +2021,11 @@ void IRSchedule::Annotate(const Expr& block, const std::string& key, const attr_ LOG(FATAL) << "Value of attribute:" << key << " input unsupported data type"; } +void IRSchedule::Unannotate(Expr& block, const std::string& key) { + impl_->Unannotate(block, key); + trace_.Append(ScheduleDesc::Step("Unannotate", {{"block", std::vector({block})}}, {{"key", key}}, {})); +} + void IRSchedule::FlattenLoops(const std::vector& loops, const bool force_flat) { impl_->FlattenLoops(loops, force_flat); trace_.Append( diff --git a/cinn/ir/ir_schedule.h b/cinn/ir/ir_schedule.h index 38a7026a98..36bdc64426 100644 --- a/cinn/ir/ir_schedule.h +++ b/cinn/ir/ir_schedule.h @@ -335,6 +335,13 @@ class IRSchedule { */ void Annotate(const Expr& block, const std::string& key, const attr_t& value); + /*! + * \brief To cancel an annotation within a block using the key + * \param block The block to be unannotated + * \param key The attribute key + */ + void Unannotate(Expr& block, const std::string& key); + /*! * \brief flatten the loops in one dim. * \param loops the loops to be flatted. diff --git a/cinn/ir/schedule_desc.cc b/cinn/ir/schedule_desc.cc index ebd2368a7e..91cad90647 100644 --- a/cinn/ir/schedule_desc.cc +++ b/cinn/ir/schedule_desc.cc @@ -452,6 +452,11 @@ CINN_BUILD_STEP_KIND(AnnotateStringAttr) .Attrs({"key", "value"}) .SetApplyFn(APPLY_FUNC_UNIFORM(AnnotateStringAttr)); +CINN_BUILD_STEP_KIND(Unannotate) + .Inputs({"block"}) + .Attrs({"key"}) + .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Unannotate))); + CINN_BUILD_STEP_KIND(FlattenLoops) .Inputs({"loops"}) .Attrs({"force_flat"}) diff --git a/cinn/ir/schedule_desc_test.cc b/cinn/ir/schedule_desc_test.cc index b33b752cbf..0dd092e07d 100644 --- a/cinn/ir/schedule_desc_test.cc +++ b/cinn/ir/schedule_desc_test.cc @@ -641,5 +641,40 @@ TEST_F(TestScheduleDesc, StepKind_Annotate) { CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); } +TEST_F(TestScheduleDesc, StepKind_Unannotate) { + lowered_funcs = LowerCompute({32, 128}, target); + ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs); + + auto block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Annotate(block_b, "k1", int(64)); + trace.Append(ScheduleDesc::Step("AnnotateIntAttr", + {{"block", std::vector({block_b})}}, + {{"key", std::string("k1")}, {"value", int(64)}}, + {})); + + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Annotate(block_b, "k2", bool(true)); + trace.Append(ScheduleDesc::Step("AnnotateBoolAttr", + {{"block", std::vector({block_b})}}, + {{"key", std::string("k2")}, {"value", bool(true)}}, + {})); + + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Unannotate(block_b, "k1"); + trace.Append( + ScheduleDesc::Step("Unannotate", {{"block", std::vector({block_b})}}, {{"key", std::string("k1")}}, {})); + + block_b = ir_sch.GetBlock("B"); + trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b})); + ir_sch.Unannotate(block_b, "k2"); + trace.Append( + ScheduleDesc::Step("Unannotate", {{"block", std::vector({block_b})}}, {{"key", std::string("k2")}}, {})); + + CheckReplayResult(ir_sch, trace); + CheckReplayResult(ir_sch, ir_sch.GetTraceDesc()); +} } // namespace ir } // namespace cinn