Skip to content

Commit

Permalink
[PASS] UnrollLoop, isolate arithmetic module. (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Feb 5, 2017
1 parent d89917b commit e42cc11
Show file tree
Hide file tree
Showing 15 changed files with 261 additions and 161 deletions.
10 changes: 9 additions & 1 deletion include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);

/*!
* \brief inline all calls of f in stmt.
Expand Down Expand Up @@ -97,6 +97,13 @@ Stmt Inline(Stmt stmt,
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);

/*!
* \brief unroll the constant loops
* \param stmt The statment to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down Expand Up @@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);


} // namespace ir
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
<< num_args << " passed"
<< "but request arg" << i;
<< " but request arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]);
}

Expand Down
1 change: 0 additions & 1 deletion python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def build(sch,
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
fsplits[i] = ir_pass.StorageSync(fsplits[i], "global")

if record_codes is not None:
output_ssa = False
Expand Down
1 change: 1 addition & 0 deletions src/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
- api API functionr registration
- lang The definition of DSL related data structure
- schedule The operations on the schedule graph before converting to IR.
- arithmetic Arithmetic expression and set simplification
- pass The optimization pass on the IR structure
- runtime Minimum runtime related codes.
10 changes: 10 additions & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/api_registry.h>

namespace tvm {
Expand All @@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
}
});

TVM_REGISTER_API(_pass_PostOrderVisit)
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
f(n);
});
});

// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
Expand All @@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
Expand Down
10 changes: 5 additions & 5 deletions src/schedule/compute_expr.h → src/arithmetic/compute_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
* \brief Utility integer expression with quick eager simplification.
* This is weaker than Simplify but can be done Eagerly.
*/
#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
#define TVM_SCHEDULE_COMPUTE_EXPR_H_
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_

#include <tvm/ir.h>
#include <pass/Interval.h>

namespace tvm {
namespace schedule {
namespace arith {

using Halide::Internal::add_would_overflow;
using Halide::Internal::sub_would_overflow;
Expand Down Expand Up @@ -104,6 +104,6 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
return Halide::Internal::Interval::make_min(a, b);
}

} // namespace schedule
} // namespace arith
} // namespace tvm
#endif // TVM_SCHEDULE_COMPUTE_EXPR_H_
#endif // TVM_ARITHMETIC_COMPUTE_EXPR_H_
96 changes: 12 additions & 84 deletions src/schedule/int_set.cc → src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* Copyright (c) 2016 by Contributors
* \file int_set_impl.cc
* Copyright (c) 2017 by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
Expand All @@ -10,7 +10,7 @@
#include "./compute_expr.h"

namespace tvm {
namespace schedule {
namespace arith {

using Halide::Internal::Interval;

Expand Down Expand Up @@ -94,6 +94,12 @@ bool IntSet::is_single_point() const {
return (s_int && s_int->i.is_single_point());
}

Expr IntSet::point_value() const {
const IntervalSet* s_int = (*this).as<IntervalSet>();
CHECK(s_int && s_int->i.is_single_point());
return s_int->i.min;
}

IntSet IntSet::everything() {
return IntervalSet::make(Interval::everything());
}
Expand All @@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) {
}

// Check if a is created from b.
inline bool MatchRange(const IntSet& a,
const Range& b) {
bool IntSet::match_range(const Range& b) const {
const IntSet& a = *this;
const IntervalSet* a_int = a.as<IntervalSet>();
if (!a_int) return false;
const Interval& i = a_int->i;
Expand Down Expand Up @@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
return CombineSets<OP>(a, b);
}

// Implementation of Evaluations and passing.
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent) {
if (dom_map.count(s->outer) &&
dom_map.count(s->inner) &&
dom_map.count(s->parent) &&
MatchRange(outer, dom_map.at(s->outer)) &&
MatchRange(inner, dom_map.at(s->inner))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr factor = dom_map.at(s->inner)->extent;
Expr parent_min = dom_map.at(s->parent)->min;
CHECK(outer.defined());
CHECK(inner.defined());
CHECK(factor.defined());

*parent = Combine<Add>(
Combine<Add>(
Combine<Mul>(outer, IntSet::single_point(factor)), inner),
IntSet::single_point(parent_min));
}

void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner) {
CHECK(dom_map.count(s->outer));
CHECK(dom_map.count(s->inner));
CHECK(dom_map.count(s->fused));

if (MatchRange(fused, dom_map.at(s->fused))) {
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}

Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;

const IntervalSet* fused_int = fused.as<IntervalSet>();

if (fused_int && fused_int->i.is_single_point()) {
Expr value = fused_int->i.min;
Expr factor = dom_map.at(s->inner)->extent;
Expr v_outer = value / factor;
Expr v_inner = value % factor;
if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
*outer = IntSet::single_point(v_outer);
*inner = IntSet::single_point(v_inner);
} else {
LOG(WARNING) << "use fallback inference rule in fuse";
// simply use the entire set, this rule can be enhanced.
*outer = IntSet::range(dom_map.at(s->outer));
*inner = IntSet::range(dom_map.at(s->inner));
return;
}
}


void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& rebased,
IntSet* parent) {
CHECK(dom_map.count(s->parent));
if (MatchRange(rebased, dom_map.at(s->rebased))) {
*parent = IntSet::range(dom_map.at(s->parent));
return;
}
Expr parent_min = dom_map.at(s->parent)->min;
*parent = Combine<Add>(rebased, IntSet::single_point(parent_min));
}

// Evaluator to evalute the epxression.
class IntSetEvaluator {
public:
Expand Down Expand Up @@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});


} // namespace schedule
} // namespace arith
} // namespace tvm
75 changes: 17 additions & 58 deletions src/schedule/int_set.h → src/arithmetic/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
* \file int_set.h
* \brief Abstraction for all integer set operations.
*/
#ifndef TVM_SCHEDULE_INT_SET_H_
#define TVM_SCHEDULE_INT_SET_H_
#ifndef TVM_ARITHMETIC_INT_SET_H_
#define TVM_ARITHMETIC_INT_SET_H_

#include <tvm/expr.h>
#include <tvm/schedule.h>

namespace tvm {
namespace schedule {
namespace arith {

// internal node container of int set.
class IntSetNode;
Expand Down Expand Up @@ -44,6 +44,18 @@ class IntSet : public NodeRef {
bool is_everything() const;
/*! \return Whether the set is a single point */
bool is_single_point() const;
/*!
* \brief The single point value, call only if is_single_point is true
* \return The point value.
*/
Expr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
* \note It is guanrateed that IntSet::range(r).match_range(r) == true
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
/*! \return Whether the set contains everything */
static IntSet everything();
/*!
Expand Down Expand Up @@ -88,59 +100,6 @@ IntSet EvalSet(Expr e,
IntSet EvalSet(Range r,
const Map<IterVar, IntSet>& dom_map);

/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Split relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param outer domain of outer iteration.
* \param inner domain of inner iteration.
* \param parent The result domain of parent.
*/
void PassUp(const SplitNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& outer,
const IntSet& inner,
IntSet* parent);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param fused domain of fused iteration.
* \param outer The result domain of outer iteration.
* \param inner The result domain of inner iteration.
*/
void PassUp(const FuseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* outer,
IntSet* inner);

/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param rebased domain of rebased iteration.
* \param parent The result domain of parent iteration.
*/
void PassUp(const RebaseNode* s,
const std::unordered_map<IterVar, Range>& dom_map,
const IntSet& fused,
IntSet* parent);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
Expand All @@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
}

} // namespace schedule
} // namespace arith
} // namespace tvm

#endif // TVM_SCHEDULE_INT_SET_H_
#endif // TVM_ARITHMETIC_INT_SET_H_
22 changes: 18 additions & 4 deletions src/pass/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,24 @@ class IRInline : public IRMutator {
if (op->func == f_) {
CHECK_EQ(op->value_index, 0);
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
CHECK_EQ(args_.size(), op->args.size());

bool has_side_effect = false;
for (size_t i = 0; i < op->args.size(); ++i) {
if (HasSideEffect(op->args[i])) has_side_effect = true;
}

if (has_side_effect) {
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
} else {
Map<Var, Expr> vmap;
for (size_t i = 0; i < args_.size(); ++i) {
vmap.Set(args_[i], op->args[i]);
}
expr = Substitute(
Evaluate::make(expr), vmap).as<Evaluate>()->value;
}
return expr;
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/pass/simple_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
std::unordered_map<const Variable*, Expr> smap;
};

Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first->var.get()] = kv.second;
m.smap[kv.first.get()] = kv.second;
}
return m.Mutate(stmt);
}
Expand Down
Loading

0 comments on commit e42cc11

Please sign in to comment.