Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Mar 5, 2025
1 parent 0b6f47e commit 5b2260e
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 0 deletions.
14 changes: 14 additions & 0 deletions csrc/ir/container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,18 @@ void IrContainer::assumeNonNegative(Val* val) {
axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal()));
}

void IrContainer::removeStatementsConstructedAfter(
int64_t prev_num_exprs,
int64_t prev_num_vals) {
while (static_cast<int64_t>(exprs_up_.size()) > prev_num_exprs) {
exprs_.erase(exprs_up_.back().get());
exprs_up_.pop_back();
}

while (static_cast<int64_t>(vals_up_.size()) > prev_num_vals) {
vals_.erase(vals_up_.back().get());
vals_up_.pop_back();
}
}

} // namespace nvfuser
12 changes: 12 additions & 0 deletions csrc/ir/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ class IrContainer : public PolymorphicBase {
return vals_;
}

int64_t numExprs() const noexcept {
return static_cast<int64_t>(exprs_up_.size());
}

int64_t numVals() const noexcept {
return static_cast<int64_t>(vals_up_.size());
}

// Shortcuts for frequently used vals
NVF_API Val* zeroVal();
NVF_API Val* oneVal();
Expand All @@ -144,6 +152,10 @@ class IrContainer : public PolymorphicBase {
void assumePositive(Val* val);
void assumeNonNegative(Val* val);

void removeStatementsConstructedAfter(
int64_t prev_num_exprs,
int64_t prev_num_vals);

protected:
static IrCloner copy(const IrContainer* from, IrContainer* to);

Expand Down
21 changes: 21 additions & 0 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,26 @@ std::pair<Val*, bool> computeIndex(

return known_indices.at(id);
}

class StatementGuard {
public:
StatementGuard() : fusion_(FusionGuard::getCurFusion()) {
// Trigger lazy initialization of axioms.
fusion_->axioms();
prev_num_exprs_ = fusion_->numExprs();
prev_num_vals_ = fusion_->numVals();
}

~StatementGuard() {
fusion_->removeStatementsConstructedAfter(prev_num_exprs_, prev_num_vals_);
}

private:
Fusion* fusion_;
int64_t prev_num_exprs_;
int64_t prev_num_vals_;
};

} // namespace

bool haveDifferentShardings(
Expand Down Expand Up @@ -374,6 +394,7 @@ bool haveDifferentShardings(

NVF_ERROR(producer->fusion() == consumer->fusion());
FusionGuard fg(producer->fusion());
StatementGuard sg;

// FIXME: can we reuse IterDomain* which is also a Val*?
// FIXME: remove Val* and Expr*
Expand Down

0 comments on commit 5b2260e

Please sign in to comment.