Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LANG] Change Schedule->Stage, Use Schedule for global schedule #8

Merged
merged 3 commits into from
Jan 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@ language: cpp

os:
- linux
- osx
# - osx

env:
# code analysis
- TASK=lint
- TASK=cpp_test
- TASK=python_test
- TASK=all_test

branches:
only:
Expand All @@ -35,6 +33,7 @@ addons:
- g++-4.8
- python-numpy
- python-nose
- python3-numpy
- python3-dev
- python3-nose
- graphviz
Expand Down
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from 1ec478 to 98e8df
145 changes: 112 additions & 33 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

namespace tvm {

// Node container for Stage
class StageNode;
// Node container for Schedule
class ScheduleNode;
// Node container for IterVarRelation
Expand All @@ -25,46 +27,48 @@ enum AttachType : int {
kScope = 3
};

/*! \brief schedule container */
class Schedule : public NodeRef {
/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
public:
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
Stage() {}
explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief create a new schedule for op.
* \param op The operator in the schedule
* \param scope The scope of the schedule
*/
Schedule(Operation op, std::string scope);
explicit Stage(Operation op);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
inline const StageNode* operator->() const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline ScheduleNode* operator->();
inline StageNode* operator->();
/*!
* \brief set the memory scope of the stage
* \param scope The memory scope.
*/
Stage& set_scope(std::string scope); // NOLINT(*)
/*!
* \brief specify the schedule to be computed at the parent schedule's scope.
* \param parent The parent schedule.
* \param scope The iteration point to carry the schedule.
* \return reference to self.
*/
Schedule& compute_at(Schedule parent, IterVar scope); // NOLINT(*)
Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline, attach it at parent.
* \param parent The parent schedule to be attached to.
* \return reference to self.
*/
Schedule& compute_inline(Schedule parent); // NOLINT(*)
Stage& compute_inline(); // NOLINT(*)
/*!
* \brief Compute the function at root, attach it to its parent.
* \param parent The parent schedule to be attached to.
* \return reference to self.
*/
Schedule& compute_root(Schedule parent); // NOLINT(*)
Stage& compute_root(); // NOLINT(*)
/*!
* \brief Split the parent by factor, generate
* \param parent The parent iteration domain.
Expand All @@ -73,7 +77,7 @@ class Schedule : public NodeRef {
* \param factor The split factor of the loop.
* \return reference to self.
*/
Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
Stage& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
/*!
* \brief Split the iteration with a given outer domain,
* the outer domain must have a thread-tag.
Expand All @@ -85,24 +89,74 @@ class Schedule : public NodeRef {
* factor must be provided such that factor * outer.extent >= parent.extent.
* \return reference to self.
*/
Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
Stage& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param inner The inner domain to be fused
* \param outer The outer domain to be fused.
* \param p_target The result target domain.
* \return reference to self.
*/
Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
*/
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*)
Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor); // NOLINT(*)
Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
/*!
* \brief Perform tiling on two dimensions
* The final loop order from outmost to inner most are
* [x_outer, y_outer, x_inner, y_inner]
*
* \param x_parent The original x dimension
* \param y_parent The original y dimension
* \param p_x_outer Outer axis of x dimension
* \param p_y_outer Outer axis of y dimension
* \param p_x_inner Inner axis of x dimension
* \param p_y_inner Inner axis of y dimension
* \param x_factor The stride factor on x axis
* \param y_factor The stride factor on y axis
* \return reference to self.
*/
Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
};

/*!
* \brief Global schedule container
* For operations and all the operations they depend on.
* The schedule per Operation is named as stage.
*/
class Schedule : public NodeRef {
public:
Schedule() {}
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
*/
explicit Schedule(Array<Operation> ops);
/*!
* \brief Get the stage corresponds to the op
* \param op The operation.
*/
Stage operator[](const Operation& op);
/*!
* \brief Short hand for getting the stage of tensor's operation.
* \param tensor The tensor
* \return The stage corresponding to the tensor's op
*/
Stage operator[](const Tensor& tensor) {
return this->operator[](tensor->op);
}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const ScheduleNode* operator->() const;
};

/*!
Expand Down Expand Up @@ -135,11 +189,11 @@ class IterVarRelation : public NodeRef {
*
* The relations connects the IterVars in the graph.
*/
class ScheduleNode : public Node {
class StageNode : public Node {
public:
/*! \brief The operation to be scheduled */
Operation op;
/*! \brief The thread scope level of the schedule */
/*! \brief The thread scope level of the stage */
std::string scope;
/*! \brief All the nodes in the iter var */
Array<IterVar> all_iter_vars;
Expand All @@ -152,12 +206,10 @@ class ScheduleNode : public Node {
Array<IterVarRelation> relations;
/*! \brief The attachment type of the schedule */
AttachType attach_type{kNone};
/*!
* \brief The attach point of this schedule.
*/
IterVar attach_parent;
/*! \brief the schedules that this schedule depend on */
Array<Schedule> children;
/*! \brief The attach point of this schedule. */
IterVar attach_ivar;
/*! \brief The stage this node attaches to */
Stage attach_stage;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
Expand All @@ -166,8 +218,31 @@ class ScheduleNode : public Node {
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("attach_type", &attach_type);
v->Visit("attach_parent", &attach_parent);
v->Visit("children", &children);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
}

static constexpr const char* _type_key = "Stage";
TVM_DECLARE_NODE_TYPE_INFO(StageNode);
};

/*! \brief node container for schedule */
class ScheduleNode : public Node {
public:
/*! \brief The root operations */
Array<Operation> roots;
/*!
* \brief list of all stages for non-placeholder ops
* The stage are ordered in PostDFS order of their op.
*/
Array<Stage> stages;
/*! \brief map of operation to the stages */
Map<Operation, Stage> stage_map;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("roots", &roots);
v->Visit("stages", &stages);
v->Visit("stage_map", &stage_map);
}

static constexpr const char* _type_key = "Schedule";
Expand Down Expand Up @@ -234,12 +309,16 @@ class FuseNode : public IterVarRelationNode {
};

// implementations
inline const StageNode* Stage::operator->() const {
return static_cast<const StageNode*>(node_.get());
}
inline StageNode* Stage::operator->() {
return static_cast<StageNode*>(node_.get());
}

inline const ScheduleNode* Schedule::operator->() const {
return static_cast<const ScheduleNode*>(node_.get());
}
inline ScheduleNode* Schedule::operator->() {
return static_cast<ScheduleNode*>(node_.get());
}

inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
Expand Down
13 changes: 11 additions & 2 deletions python/tvm/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,17 @@ def max(expr, rdom):
return x


def Schedule(tensor, scope="global"):
return _function_internal._Schedule(tensor, scope)
def Schedule(ops):
"""Create a schedule for list of ops

Parameters
----------
ops : list of Operations
The source expression.
"""
if not isinstance(ops, (list, _collections.Array)):
ops = [ops]
return _function_internal._Schedule(ops)


_init_function_module("tvm")
Loading