Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Fix anchor-block flow with empty design space generator #14047

Merged
merged 2 commits into from
Feb 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ class TuningRecordNode : public runtime::Object {
* argument information.
*/
ObjectRef AsJSON() const;
/*!
* \brief Check if this tuning record has valid trace instructions and successful run results.
* \return The check result.
*/
bool IsValid() const;
};

/*!
Expand Down Expand Up @@ -210,7 +215,7 @@ class DatabaseNode : public runtime::Object {
*/
virtual void CommitTuningRecord(const TuningRecord& record) = 0;
/*!
* \brief Get the top K tuning records of given workload from the database.
* \brief Get the top K valid tuning records of given workload from the database.
* \param workload The workload to be searched for.
* \param top_k The number of top records to be returned.
* \return An array of top K tuning records for the given workload.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def commit_tuning_record(self, record: TuningRecord) -> None:
_ffi_api.DatabaseCommitTuningRecord(self, record) # type: ignore # pylint: disable=no-member

def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
"""Get the top K tuning records of given workload from the database.
"""Get the top K valid tuning records of given workload from the database.

Parameters
----------
Expand Down
15 changes: 15 additions & 0 deletions src/meta_schedule/database/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ ObjectRef TuningRecordNode::AsJSON() const {
json_args_info};
}

bool TuningRecordNode::IsValid() const {
if (!GetNumValidInstructions(trace->insts, /*remove_postproc*/ true)) {
return false;
}
if (run_secs.defined()) {
for (const auto& run_sec : run_secs.value()) {
// kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
if (run_sec.defined() && run_sec->value != SortTuningRecordByMeanRunSecs::kMaxMeanTime) {
return true;
}
}
}
return false;
}

TuningRecord TuningRecord::FromJSON(const ObjectRef& json_obj, const Workload& workload) {
tir::Trace trace{nullptr};
Optional<Array<FloatImm>> run_secs{nullptr};
Expand Down
12 changes: 3 additions & 9 deletions src/meta_schedule/database/json_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,7 @@ class JSONDatabaseNode : public DatabaseNode {
results.reserve(top_k);
for (const TuningRecord& record : this->tuning_records_) {
auto run_secs = record->run_secs;
if (!run_secs.defined() || run_secs.value().empty() ||
std::all_of(run_secs.value().begin(), run_secs.value().end(),
// kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
[](tvm::FloatImm v) {
return v.defined() &&
v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
})) {
if (!record->IsValid()) {
continue;
}
if (record->workload.same_as(workload) ||
Expand All @@ -146,8 +140,8 @@ class JSONDatabaseNode : public DatabaseNode {
}
}
if (results.size() < static_cast<size_t>(top_k)) {
LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
"enough valid records in the database for this workload.";
LOG(WARNING) << "Returned tuning records less than requested(" << results.size() << " of "
<< top_k << " asked).";
}
return results;
}
Expand Down
13 changes: 3 additions & 10 deletions src/meta_schedule/database/memory_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,7 @@ class MemoryDatabaseNode : public DatabaseNode {
std::vector<TuningRecord> results;
results.reserve(records.size());
for (const TuningRecord& record : records) {
auto run_secs = record->run_secs;
if (!run_secs.defined() || run_secs.value().empty() ||
std::all_of(run_secs.value().begin(), run_secs.value().end(),
// kMaxMeanTime(1e10) is used as a stub for undefined measurement times.
[](tvm::FloatImm v) {
return v.defined() &&
v->value == SortTuningRecordByMeanRunSecs::kMaxMeanTime;
})) {
if (!record->IsValid()) {
continue;
}
if (record->workload.same_as(workload) ||
Expand All @@ -88,8 +81,8 @@ class MemoryDatabaseNode : public DatabaseNode {
return {results.begin(), results.begin() + top_k};
} else {
if (results.size() < static_cast<size_t>(top_k)) {
LOG(WARNING) << "The size of the GetTopK result is smaller than requested. There are not "
"enough valid records in the database for this workload.";
LOG(WARNING) << "Returned tuning records less than requested(" << results.size() << " of "
<< top_k << " asked).";
}
return results;
}
Expand Down
4 changes: 3 additions & 1 deletion src/meta_schedule/trace_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ void InlinePostBlocks(Schedule sch, Trace anchor_trace, Target target) {
// Spatial blocks which are not referenced in the anchor trace will be inlined here.
auto block_sref = sch->GetSRef(block);
if (IsSpatial(block_sref) && !get_block_names.count(name)) {
if (IsOutputBlock(sch->state(), block_sref, GetScopeRoot(sch->state(), block_sref, false))) {
StmtSRef scopeRoot =
(name != "root") ? GetScopeRoot(sch->state(), block_sref, false) : block_sref;
if (IsOutputBlock(sch->state(), block_sref, scopeRoot)) {
last_block_idx = inline_todos.size();
}
inline_todos.push_back(name);
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,14 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const Array<ObjectRef>& new_outputs,
std::unordered_map<const Object*, const Object*>* rv_map);

/*!
* \brief Counts the number of trace instructions.
* \param insts The instructions representing a trace.
* \param remove_postproc If postprocessing instructions are removed.
* \return Number of instructions.
*/
int GetNumValidInstructions(const Array<Instruction>& insts, bool remove_postproc);

} // namespace tir
} // namespace tvm

Expand Down
14 changes: 11 additions & 3 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ def test_module_equality_ignore_ndarray():
np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)


def _test_anchor_tuning(target):
def _test_anchor_tuning(target, space):
data_shape = (128, 128)
weight_shape1 = (128, 128)
weight_shape2 = (128, 128)
Expand Down Expand Up @@ -756,6 +756,7 @@ def _test_anchor_tuning(target):
target=target,
params=params,
work_dir=work_dir,
space=space,
max_trials_global=4,
strategy="replay-trace",
module_equality=module_equality,
Expand All @@ -779,8 +780,15 @@ def _test_anchor_tuning(target):
np.testing.assert_allclose(ref, out, atol=1e-3)


def test_anchor_tuning_cpu():
_test_anchor_tuning("llvm --num-cores=4")
@pytest.mark.parametrize(
"space",
[
ms.space_generator.PostOrderApply(),
ms.space_generator.PostOrderApply(sch_rules=[], postprocs=[], mutator_probs={}),
],
)
def test_anchor_tuning_cpu(space):
_test_anchor_tuning("llvm --num-cores=4", space)


def test_anchor_tuning_cpu_link_params():
Expand Down