Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

ADD Unannotate Schedule Primitives #1126

Merged
merged 4 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions cinn/backends/ir_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2731,6 +2731,50 @@ TEST(IrSchedule, Annotate) {
ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), expected_expr);
}

TEST(IrSchedule, Unannotate) {
Context::Global().ResetNameId();
Expr M(32);
Expr N(32);
Placeholder<float> A("A", {M, N});
auto B = Compute(
{M, N}, [&](Var i, Var j) { return A(i, j); }, "B");

auto funcs = cinn::lang::LowerVec(
"test_split_and_fuse1", CreateStages({A, B}), {A, B}, {}, {}, nullptr, common::DefaultHostTarget(), true);
ir::IRSchedule ir_sch(ir::ModuleExpr({funcs[0]->body}));
auto fused = ir_sch.Fuse("B", {0, 1});
auto block_b = ir_sch.GetBlock("B");
ir_sch.Annotate(block_b, "k1", int(64));
block_b = ir_sch.GetBlock("B");
ir_sch.Annotate(block_b, "k2", bool(true));
block_b = ir_sch.GetBlock("B");
ir_sch.Annotate(block_b, "k3", float(2.0));
block_b = ir_sch.GetBlock("B");
ir_sch.Annotate(block_b, "k4", std::string("v4"));
block_b = ir_sch.GetBlock("B");
ir_sch.Unannotate(block_b, "k1");
block_b = ir_sch.GetBlock("B");
ir_sch.Unannotate(block_b, "k2");
block_b = ir_sch.GetBlock("B");
ir_sch.Unannotate(block_b, "k3");
block_b = ir_sch.GetBlock("B");
ir_sch.Unannotate(block_b, "k4");
std::string expected_expr = R"ROC({
ScheduleBlock(root)
{
serial for (i_j_fused, 0, 1024)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind((i_j_fused / 32), (i_j_fused % 32))
B[i0, i1] = A[i0, i1]
}
}
}
})ROC";
ASSERT_EQ(utils::GetStreamCnt(ir_sch.GetModule().GetExprs().front()), expected_expr);
}

TEST(IrSchedule, ComplexIndices) {
Target target = common::DefaultHostTarget();
ir::Expr M(32);
Expand Down
18 changes: 18 additions & 0 deletions cinn/ir/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class ScheduleImpl {
Expr Rfactor(const Expr& rf_loop, int rf_axis);
Expr AddUnitLoop(const Expr& block) const;
void Annotate(const Expr& block, const std::string& key, const attr_t& value);
void Unannotate(Expr& block, const std::string& key);
void FlattenLoops(const std::vector<Expr>& loops, const bool force_flat = false);
void CopyTransformAndLoopInfo(const Expr& block, const Expr& block_target);
void CopyTransformAndLoopInfo(const std::string& block_name, const std::string& block_target_name);
Expand Down Expand Up @@ -1509,6 +1510,18 @@ void ScheduleImpl::Annotate(const Expr& block, const std::string& key, const att
this->Replace(block, copied_block);
}

void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) {
CHECK(block.As<ir::ScheduleBlockRealize>());
CHECK(block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>());
auto* schedule_block = block.As<ir::ScheduleBlockRealize>()->schedule_block.As<ir::ScheduleBlock>();
if (schedule_block->attrs.count(ann_key)) {
schedule_block->attrs.erase(ann_key);
} else {
LOG(WARNING) << "Can't find annotation with key: " << ann_key;
return;
}
}

void ScheduleImpl::FlattenLoops(const std::vector<Expr>& loops, const bool flat_tensor) {
CHECK_GT(loops.size(), 0) << "Loops can't be empty!";
// compute loop
Expand Down Expand Up @@ -2008,6 +2021,11 @@ void IRSchedule::Annotate(const Expr& block, const std::string& key, const attr_
LOG(FATAL) << "Value of attribute:" << key << " input unsupported data type";
}

void IRSchedule::Unannotate(Expr& block, const std::string& key) {
impl_->Unannotate(block, key);
trace_.Append(ScheduleDesc::Step("Unannotate", {{"block", std::vector<Expr>({block})}}, {{"key", key}}, {}));
AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
}

void IRSchedule::FlattenLoops(const std::vector<Expr>& loops, const bool force_flat) {
impl_->FlattenLoops(loops, force_flat);
trace_.Append(
Expand Down
7 changes: 7 additions & 0 deletions cinn/ir/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,13 @@ class IRSchedule {
*/
void Annotate(const Expr& block, const std::string& key, const attr_t& value);

/*!
* \brief To cancel an annotation within a block using the key
* \param block The block to be unannotated
* \param key The attribute key
*/
void Unannotate(Expr& block, const std::string& key);

/*!
* \brief flatten the loops in one dim.
* \param loops the loops to be flatted.
Expand Down
5 changes: 5 additions & 0 deletions cinn/ir/schedule_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,11 @@ CINN_BUILD_STEP_KIND(AnnotateStringAttr)
.Attrs({"key", "value"})
.SetApplyFn(APPLY_FUNC_UNIFORM(AnnotateStringAttr));

CINN_BUILD_STEP_KIND(Unannotate)
.Inputs({"block"})
.Attrs({"key"})
.SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Unannotate)));

AndPuQing marked this conversation as resolved.
Show resolved Hide resolved
CINN_BUILD_STEP_KIND(FlattenLoops)
.Inputs({"loops"})
.Attrs({"force_flat"})
Expand Down
35 changes: 35 additions & 0 deletions cinn/ir/schedule_desc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -641,5 +641,40 @@ TEST_F(TestScheduleDesc, StepKind_Annotate) {
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}

TEST_F(TestScheduleDesc, StepKind_Unannotate) {
lowered_funcs = LowerCompute({32, 128}, target);
ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs);

auto block_b = ir_sch.GetBlock("B");
trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
ir_sch.Annotate(block_b, "k1", int(64));
trace.Append(ScheduleDesc::Step("AnnotateIntAttr",
{{"block", std::vector<Expr>({block_b})}},
{{"key", std::string("k1")}, {"value", int(64)}},
{}));

block_b = ir_sch.GetBlock("B");
trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
ir_sch.Annotate(block_b, "k2", bool(true));
trace.Append(ScheduleDesc::Step("AnnotateBoolAttr",
{{"block", std::vector<Expr>({block_b})}},
{{"key", std::string("k2")}, {"value", bool(true)}},
{}));

block_b = ir_sch.GetBlock("B");
trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
ir_sch.Unannotate(block_b, "k1");
trace.Append(
ScheduleDesc::Step("Unannotate", {{"block", std::vector<Expr>({block_b})}}, {{"key", std::string("k1")}}, {}));

block_b = ir_sch.GetBlock("B");
trace.Append(ScheduleDesc::Step("GetBlock", {}, {{"block_name", std::string("B")}}, {block_b}));
ir_sch.Unannotate(block_b, "k2");
trace.Append(
ScheduleDesc::Step("Unannotate", {{"block", std::vector<Expr>({block_b})}}, {{"key", std::string("k2")}}, {}));

CheckReplayResult(ir_sch, trace);
CheckReplayResult(ir_sch, ir_sch.GetTraceDesc());
}
} // namespace ir
} // namespace cinn