Skip to content

Commit

Permalink
Construct FusionExecutorCaches with a ready-to-run fusion (#3969)
Browse files Browse the repository at this point in the history
As requested by myself at
#3923 (comment)

This PR tries to fix the Python frontend to construct
`FusionExecutorCache`s with a ready-to-run fusion. Previously, the code
constructed a FusionExecutorCache with an empty fusion and populated it
later. This is not how we normally use and test FusionExecutorCache and
is indeed problematic in some situations. For example,
`FusionExecutorCache::exact_map_` was always empty because the
constructor thought it was going to run an empty fusion. See one of the
tests fixed by this PR.

cc @samnordmann
  • Loading branch information
wujingyue authored Feb 27, 2025
1 parent 9efab67 commit 7e9bd56
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 23 deletions.
73 changes: 57 additions & 16 deletions csrc/python_frontend/fusion_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,34 @@ FusionSchedules::FusionSchedules(int64_t fusion_id)
last_user_def_scheduled_ir(nullptr),
last_user_def_executor(nullptr),
scheds_lock(),
fusion_id_{fusion_id} {
auto_gen_schedules = std::make_unique<FusionExecutorCache>(
std::make_unique<Fusion>(), fusion_id);
fusion_id_(fusion_id) {
presched_fusion_ = std::make_unique<Fusion>();
}

Fusion* FusionSchedules::preschedFusion() {
auto fusion = auto_gen_schedules->fusion();
NVF_CHECK(fusion != nullptr, "Prescheduled Fusion is unexpectedly null!");
return fusion;
if (presched_fusion_ != nullptr) {
return presched_fusion_.get();
}

// Ideally, we shouldn't have to access FusionExecutorCache::fusion() so
// FusionExecutorCache has the flexibility to modify it in place or even
// delete it. Currently, this is only needed for cloning an
// nvfuser.FusionDefinition. See exec_nvfuser's is_clonable parameter. After
// FusionDefinition.__exit__, FusionSchedules.presched_fusion_ is moved to
// FusionExecutorCache and therefore becomes null.
if (auto_gen_schedules != nullptr) {
return auto_gen_schedules->fusion();
}

NVF_THROW("Prescheduled Fusion is unexpectedly null!");
}

// Lazily builds the FusionExecutorCache held in auto_gen_schedules. The first
// call moves presched_fusion_ into the newly created cache; later calls find
// an existing executor and leave it untouched so it is reused.
void FusionSchedules::createExecutorIfNotExists() {
  if (auto_gen_schedules != nullptr) {
    // An executor has already been created; nothing to do.
    return;
  }

  auto_gen_schedules = std::make_unique<FusionExecutorCache>(
      std::move(presched_fusion_), fusion_id_);
  // Make the handed-over state explicit: preschedFusion() checks this pointer
  // for null before falling back to the executor's fusion.
  presched_fusion_ = nullptr;
}

TrieNode::TrieNode(RecordFunctor* rec, TrieNode* _parent, size_t _fusion_id)
Expand Down Expand Up @@ -478,7 +497,7 @@ std::optional<TrieNode*> FusionCache::queryChildren(
FusionSchedules* FusionCache::queryFusionSchedules(size_t fusion_id) const {
NVF_CHECK(
fusion_id < fusions_.size(),
"Invalid scheduler query for id:",
"Invalid scheduler query for id: ",
fusion_id);
FusionSchedules* ptr = fusions_.at(fusion_id).get();
NVF_CHECK(ptr != nullptr, "Unexpected null FusionSchedules object.");
Expand Down Expand Up @@ -670,12 +689,25 @@ void FusionCache::serialize(std::string filename) const {
std::vector<fb_fusion_executor_cache> fb_auto_gen_schedules;
fb_auto_gen_schedules.reserve(terminal_nodes_.size());

for (auto node : terminal_nodes_) {
for (TrieNode* node : terminal_nodes_) {
if (node->getException().has_value()) {
// Skip error nodes, which don't map to any FusionSchedules in the cache.
// Without this, queryFusionSchedules creates an empty FusionSchedules
// that's not executable.
continue;
}

FusionSchedules* schedule = queryFusionSchedules(node->fusion_id);
if (schedule->auto_gen_schedules == nullptr) {
// This fusion has been created but never executed. It doesn't save us
// anything to serialize that.
continue;
}

terminal_node_idx.push_back(
map_record_functor_to_trie_node_id.at(node->record.get()));

auto schedule = queryFusionSchedules(node->fusion_id);
fb_auto_gen_schedules.emplace_back(
fb_auto_gen_schedules.push_back(
schedule->auto_gen_schedules->serialize(builder));
}

Expand Down Expand Up @@ -733,10 +765,16 @@ void FusionCache::deserialize(std::string filename) {
max_fusions_ = fusion_cache_buffer->max_fusions();

// 2. Deserialize fusions: (Fusion) and structure: (TrieNode) fields
std::generate_n(
std::back_inserter(fusions_),
fusion_cache_buffer->terminal_nodes()->size(),
[] { return std::make_unique<FusionSchedules>(); });
int64_t num_fusions = 0;
for (const auto i :
c10::irange(fusion_cache_buffer->auto_gen_schedules()->size())) {
num_fusions = std::max(
num_fusions,
fusion_cache_buffer->auto_gen_schedules()->Get(i)->fusion_id() + 1);
}
std::generate_n(std::back_inserter(fusions_), num_fusions, [] {
return std::make_unique<FusionSchedules>();
});

serde::RecordFunctorFactory record_functor_factory;

Expand Down Expand Up @@ -837,13 +875,16 @@ void FusionCache::deserialize(std::string filename) {

std::atomic<bool> detect_exception_in_thread_pool{false};
// Deserialize terminal_nodes field in the FusionCache table
for (auto idx : c10::irange(fusions_.size())) {
for (auto idx : c10::irange(fusion_cache_buffer->terminal_nodes()->size())) {
auto node_idx = fusion_cache_buffer->terminal_nodes()->Get(idx);
auto trie_node = bfs_order.at(node_idx);
terminal_nodes_.push_back(trie_node);

auto fb_fec_node = fusion_cache_buffer->auto_gen_schedules()->Get(idx);
auto fusion_schedule = queryFusionSchedules(trie_node->fusion_id);
FusionSchedules* fusion_schedule =
queryFusionSchedules(trie_node->fusion_id);
// Create an executor so the following code can deserialize it.
fusion_schedule->createExecutorIfNotExists();

if (!isOptionDisabled(DisableOption::ParallelSerde)) {
// Parallelize the deserialization of each FusionExecutorCache.
Expand Down
19 changes: 17 additions & 2 deletions csrc/python_frontend/fusion_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,20 @@ struct UserSchedule {
//! \struct FusionSchedules
//! \brief A container for auto generated and user defined schedules
//! that correspond to compiled kernels for each complete Fusion Definition.
struct FusionSchedules {
FusionSchedules(int64_t fusion_id = 0);
class FusionSchedules {
public:
explicit FusionSchedules(int64_t fusion_id = 0);

Fusion* preschedFusion();

//! Called during execution to create a FusionExecutorCache. It's created
//! during execution instead of by finalizeDefinition because
//! finalizeDefinition may be followed by finalizeMultideviceSchedule which
//! can modify presched_fusion_. The if-not-exists check is necessary because
//! multiple FusionDefinitions may map to the same FusionSchedules. In that
//! case, we want to reuse the same executor.
void createExecutorIfNotExists();

//! Schedules Automatically generated by nvFuser for dynamic inputs. (default)
//! NOTE: The FusionExecutorCache also holds the Unscheduled Fusion IR
std::unique_ptr<FusionExecutorCache> auto_gen_schedules;
Expand All @@ -116,6 +126,11 @@ struct FusionSchedules {
std::vector<int64_t> outputs_fid_;
//! Map Fusion Val to its corresponding FusionDefinition index
std::unordered_map<const Val*, int64_t> map_value_to_fid_;

private:
//! Holds the presched fusion that will be `std::move`d to a
//! FusionExecutorCache at first execution.
std::unique_ptr<Fusion> presched_fusion_;
};

//! \struct TrieNode
Expand Down
21 changes: 21 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ std::vector<DistributedTensor> FusionDefinition::execute(

KernelArgumentHolder outputs;
if (user_sched == nullptr) {
scheds->createExecutorIfNotExists();
outputs = scheds->auto_gen_schedules->runFusionWithInputs(
args, std::nullopt, args.getDeviceIndex());
} else {
Expand Down Expand Up @@ -535,6 +536,11 @@ std::string FusionDefinition::lastCudaCode(
result = user_exec->compiledKernel()->kernelString();
}
} else {
NVF_CHECK(
scheds->auto_gen_schedules != nullptr,
"Fusion ",
*id(),
" has never been executed via FusionExecutorCache.");
result = scheds->auto_gen_schedules->getMostRecentCode(intrinsic_code);
}
return result;
Expand Down Expand Up @@ -563,6 +569,11 @@ std::string FusionDefinition::cudaCodeFor(
}
}
}
NVF_CHECK(
scheds->auto_gen_schedules != nullptr,
"Fusion ",
*id(),
" has never been executed via FusionExecutorCache.");
return scheds->auto_gen_schedules->getCodeFor(args, intrinsic_code);
}

Expand All @@ -579,6 +590,11 @@ std::string FusionDefinition::lastScheduledFusionIr(
user_sched_ir->print(ss, tensor_transforms);
result = ss.str();
} else {
NVF_CHECK(
scheds->auto_gen_schedules != nullptr,
"Fusion ",
*id(),
" has never been executed via FusionExecutorCache.");
result =
scheds->auto_gen_schedules->getMostRecentScheduledIr(tensor_transforms);
}
Expand Down Expand Up @@ -607,6 +623,11 @@ std::string FusionDefinition::scheduledFusionIrFor(
return ss.str();
}
}
NVF_CHECK(
scheds->auto_gen_schedules != nullptr,
"Fusion ",
*id(),
" has never been executed via FusionExecutorCache.");
return scheds->auto_gen_schedules->getScheduledIrFor(args, tensor_transforms);
}

Expand Down
6 changes: 5 additions & 1 deletion tests/python/opinfo_input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,11 @@ def cat_error_generator(op, dtype=torch.float32, requires_grad: bool = False, **
"Unexpected number of dimensions",
)
# All tensors must have same shape except for the cat dimension
shape_mismatch = (([(2, 3), (4, 5)], 0), RuntimeError, "Tried to bind to a value")
shape_mismatch = (
([(2, 3), (4, 5)], 0),
RuntimeError,
"a conflict was found with 2 different sizes",
)

error_cases = [
empty_input_tensors,
Expand Down
16 changes: 12 additions & 4 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1803,9 +1803,13 @@ def schedule(self):
for fd in test_defs:
# Attempting to get the cuda code for an un-executed FusionDefinition
# should trigger a RuntimeError and not a segfault
with self.assertRaisesRegex(RuntimeError, "Invalid fusion definition!"):
with self.assertRaisesRegex(
RuntimeError, "(Invalid fusion definition!|never been executed)"
):
_ = fd.last_cuda_code()
with self.assertRaisesRegex(RuntimeError, "Invalid fusion definition!"):
with self.assertRaisesRegex(
RuntimeError, "(Invalid fusion definition!|never been executed)"
):
_ = fd.last_scheduled_fusion_ir()
# Only make this check for function based definitions
if hasattr(super(type(self), self), "definition"):
Expand Down Expand Up @@ -1840,9 +1844,13 @@ def schedule(self):

# Attempt to get strings for inputs that do not heuristically match
# and a new fusion has not been compiled
with self.assertRaisesRegex(RuntimeError, "Fusion is not compiled!"):
with self.assertRaisesRegex(
RuntimeError, "(not compiled|never been executed)"
):
_ = fd.cuda_code_for(big_inputs)
with self.assertRaisesRegex(RuntimeError, "Fusion is not compiled!"):
with self.assertRaisesRegex(
RuntimeError, "(not compiled|never been executed)"
):
_ = fd.scheduled_fusion_ir_for(big_inputs)

# It is necessary to reset the Fusion Cache
Expand Down

0 comments on commit 7e9bd56

Please sign in to comment.