Skip to content

Commit

Permalink
Add CanProveSinglePoint to serve previous stronger checks.
Browse files Browse the repository at this point in the history
Add symbolic recursion depth guard to avoid nested recursion.
  • Loading branch information
tqchen committed Apr 7, 2023
1 parent d447d31 commit 773030c
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 21 deletions.
16 changes: 16 additions & 0 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,22 @@ class IntSet : public ObjectRef {
bool IsEverything() const;
/*! \return Whether the set is a single point */
bool IsSinglePoint() const;
/*!
* \brief Check if we can prove it is a single point.
*
* Unlike IsSinglePoint, which only checks ptr equality
* this function will invoke analyzer to do stonger proofs
* but also takes longer time.
*
* Use this function in some of the primitives but do not
* use it in the inner loop of simplification.
*
* \param ana Analyzer used in the proof.
* \return Whether we can prove it is a single point
*/
bool CanProveSinglePoint(Analyzer* ana) const;
// TODO(tvm-team): update all CanProve to explicitly take
// analyzer to encourage more analyzer reuse
/*! \return Whether the set is proved to be bigger than 0 */
bool CanProvePositive() const;
/*! \return Whether the set is proved to be smaller than 0 */
Expand Down
7 changes: 7 additions & 0 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,13 @@ bool IntSet::IsSinglePoint() const {
return (s_int && s_int->IsSinglePoint());
}

bool IntSet::CanProveSinglePoint(Analyzer* ana) const {
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
if (!s_int) return false;
if (s_int->IsSinglePoint()) return true;
return ana->CanProveEqual(s_int->min_value, s_int->max_value);
}

bool IntSet::CanProvePositive() const {
Analyzer analyzer;
const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
Expand Down
38 changes: 22 additions & 16 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,22 +177,28 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val
return CompareResult::kLE;
}
// try to use normal int set analysis to handle symbolic comparisons
IntSet iset = analyzer_->int_set(diff - make_const(diff->dtype, val));
if (iset.HasUpperBound()) {
ConstIntBound relaxed_upper_bound = analyzer_->const_int_bound(this->VisitExpr(iset.max()));
if (relaxed_upper_bound->max_value < 0) {
return CompareResult::kLT;
}
if (relaxed_upper_bound->max_value <= 0) {
return CompareResult::kLE;
}
} else if (iset.HasLowerBound()) {
ConstIntBound relaxed_lower_bound = analyzer_->const_int_bound(this->VisitExpr(iset.min()));
if (relaxed_lower_bound->min_value > 0) {
return CompareResult::kGT;
}
if (relaxed_lower_bound->min_value >= 0) {
return CompareResult::kGE;
// NOTE: symbolic set evaluation may recursively call comparison analysis
// in such cases, we skip recursive symbolic set evaluation to avoid infinite recursion
if (symbolic_set_eval_depth_ < kMaxSymbolicSetEvalDepth) {
++symbolic_set_eval_depth_;
IntSet iset = analyzer_->int_set(diff - make_const(diff->dtype, val));
--symbolic_set_eval_depth_;
if (iset.HasUpperBound()) {
ConstIntBound relaxed_upper_bound = analyzer_->const_int_bound(this->VisitExpr(iset.max()));
if (relaxed_upper_bound->max_value < 0) {
return CompareResult::kLT;
}
if (relaxed_upper_bound->max_value <= 0) {
return CompareResult::kLE;
}
} else if (iset.HasLowerBound()) {
ConstIntBound relaxed_lower_bound = analyzer_->const_int_bound(this->VisitExpr(iset.min()));
if (relaxed_lower_bound->min_value > 0) {
return CompareResult::kGT;
}
if (relaxed_lower_bound->min_value >= 0) {
return CompareResult::kGE;
}
}
}
// modular analysis
Expand Down
5 changes: 4 additions & 1 deletion src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
protected:
// counter to record recursive rewrite depth.
int recur_depth_{0};
// counter to record recursive comparison depth that invokes set analysis
int symbolic_set_eval_depth_{0};
// internal variable map
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;

Expand All @@ -104,7 +106,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {

// maximum number of recursion allowed during a single pass.
static const constexpr int kMaxRecurDepth = 5;

// maximum number of set eval recursion allowed during a single pass.
static const constexpr int kMaxSymbolicSetEvalDepth = 1;
/*!
* \brief try to compare x against val.
* \param x The expression to be evaluated.
Expand Down
4 changes: 3 additions & 1 deletion src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class BlockReadWriteDetector : public StmtExprVisitor {
Map<Var, Buffer> buffer_var_map_;
/*! \brief The target buffer var mapping to its matching */
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
/*!\ brief Internal analyzer. */
arith::Analyzer ana_;

/*!
* \brief Update read/write buffers and regions with provided buffer and region
Expand Down Expand Up @@ -318,7 +320,7 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
for (size_t j = 0; j < regions[i].size(); j++) {
const tvm::arith::IntSet& range = regions[i][j];
if (range.IsSinglePoint()) {
if (range.CanProveSinglePoint(&ana_)) {
PrimExpr min = range.min();
region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1)));
} else {
Expand Down
7 changes: 4 additions & 3 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,9 @@ void UpdateBlockVarDomainDimwise(
arith::IntSet required = required_region[i];
PrimExpr dim_max = max(buffer->shape[i] - 1, 0);

if (provided.IsSinglePoint() && is_const_int(provided.min())) {
ICHECK(required.IsSinglePoint() && analyzer->CanProveEqual(provided.min(), required.min()));
if (provided.CanProveSinglePoint(analyzer) && is_const_int(provided.min())) {
ICHECK(required.CanProveSinglePoint(analyzer) &&
analyzer->CanProveEqual(provided.min(), required.min()));
continue;
}

Expand Down Expand Up @@ -515,7 +516,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const Array<IterVar>&
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms) {
// we only support single point provided region now, which could cover most cases
for (const auto& intset : provided_region) {
if (!intset.IsSinglePoint()) return false;
if (!intset.CanProveSinglePoint(analyzer)) return false;
}
// calculate forward mapping (block vars -> provided region point)
Map<Var, Range> dom_map;
Expand Down

0 comments on commit 773030c

Please sign in to comment.