From f07a3273d88ab8a4570278e11f37e28a157dff87 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Tue, 10 Jan 2017 23:19:31 -0800
Subject: [PATCH 1/3] [LANG] Change Schedule->Stage, Use Schedule for global
 schedule

---
 .travis.yml                                   |   4 +-
 HalideIR                                      |   2 +-
 include/tvm/schedule.h                        | 145 ++++++++++++++----
 python/tvm/function.py                        |  13 +-
 python/tvm/schedule.py                        |  63 +++++---
 src/c_api/c_api_lang.cc                       |  49 +++---
 src/pass/schedule_ops.cc                      |  48 ++----
 src/schedule/bound.cc                         |  49 +++---
 src/schedule/graph.cc                         |  23 ++-
 src/schedule/graph.h                          |   8 +-
 .../schedule.cc => schedule/schedule_lang.cc} |  75 ++++++---
 tests/python/test_lang_schedule.py            |  46 +++---
 tests/python/test_pass_schedule_ops.py        |  29 ++--
 tests/python/test_schedule_bound_inference.py |  43 +++---
 tests/travis/run_test.sh                      |  34 ++--
 tests/travis/setup.sh                         |  13 +-
 16 files changed, 379 insertions(+), 265 deletions(-)
 rename src/{lang/schedule.cc => schedule/schedule_lang.cc} (74%)

diff --git a/.travis.yml b/.travis.yml
index 266ca62a6b78..2ae5bb0793d1 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -8,9 +8,7 @@ os:
 
 env:
   # code analysis
-  - TASK=lint
-  - TASK=cpp_test
-  - TASK=python_test
+  - TASK=all_test
 
 branches:
   only:
diff --git a/HalideIR b/HalideIR
index 1ec478bbd0c2..98e8df564f85 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit 1ec478bbd0c20b8659f0c897363b5a76e13ef495
+Subproject commit 98e8df564f8543b337ec0528dbcb06a30f91e694
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 681c6cd7842d..cbb3cc81c0d3 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -12,6 +12,8 @@
 
 namespace tvm {
 
+// Node container for Stage
+class StageNode;
 // Node container for Schedule
 class ScheduleNode;
 // Node container for IterVarRelation
@@ -25,46 +27,48 @@ enum AttachType : int {
   kScope = 3
 };
 
-/*! \brief schedule container */
-class Schedule : public NodeRef {
+/*! \brief Stage, contains scheduling for a stage of computation. */
+class Stage : public NodeRef {
  public:
-  Schedule() {}
-  explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
+  Stage() {}
+  explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {}
   /*!
    * \brief create a new stage for op.
    * \param op The operator in the stage
-   * \param scope The scope of the schedule
    */
-  Schedule(Operation op, std::string scope);
+  explicit Stage(Operation op);
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
    */
-  inline const ScheduleNode* operator->() const;
+  inline const StageNode* operator->() const;
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
   */
-  inline ScheduleNode* operator->();
+  inline StageNode* operator->();
+  /*!
+   * \brief set the memory scope of the stage
+   * \param scope The memory scope.
+   */
+  Stage& set_scope(std::string scope);  // NOLINT(*)
   /*!
    * \brief specify the stage to be computed at the parent stage's scope.
    * \param parent The parent stage.
   * \param scope The iteration point at which to attach the computation.
    * \return reference to self.
    */
-  Schedule& compute_at(Schedule parent, IterVar scope);  // NOLINT(*)
+  Stage& compute_at(Stage parent, IterVar scope);  // NOLINT(*)
   /*!
    * \brief Compute the function inline.
-   * \param parent The parent schedule to be attached to.
    * \return reference to self.
   */
-  Schedule& compute_inline(Schedule parent);  // NOLINT(*)
+  Stage& compute_inline();  // NOLINT(*)
   /*!
   * \brief Compute the function at root scope.
   * \return reference to self.
   */
-  Schedule& compute_root(Schedule parent);  // NOLINT(*)
+  Stage& compute_root();  // NOLINT(*)
  /*!
   * \brief Split the parent by factor, generating the outer and inner iteration variables.
   * \param parent The parent iteration domain.
@@ -73,7 +77,7 @@ class Schedule : public NodeRef {
   * \param factor The split factor of the loop.
   * \return reference to self.
   */
-  Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor);  // NOLINT(*)
+  Stage& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor);  // NOLINT(*)
  /*!
   * \brief Split the iteration with a given outer domain,
   *  the outer domain must have a thread-tag.
   *
@@ -85,7 +89,7 @@ class Schedule : public NodeRef {
   *  factor must be provided such that factor * outer.extent >= parent.extent.
   * \return reference to self.
   */
-  Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr());  // NOLINT(*)
+  Stage& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr());  // NOLINT(*)
  /*!
   * \brief Fuse the inner and outer domains to the target domain.
   * \param inner The inner domain to be fused
   * \param outer The outer domain to be fused
@@ -93,16 +97,66 @@ class Schedule : public NodeRef {
   * \param p_target The result target domain.
   * \return reference to self.
   */
-  Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target);  // NOLINT(*)
+  Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target);  // NOLINT(*)
  /*!
   * \brief Reorder the iteration
   * \param order The order of iteration variables.
   * \return reference to self.
   */
-  Schedule& reorder(const Array<IterVar>& order);  // NOLINT(*)
-  Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
-                 IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
-                 Expr x_factor, Expr y_factor);  // NOLINT(*)
+  Stage& reorder(const Array<IterVar>& order);  // NOLINT(*)
+  /*!
+   * \brief Perform tiling on two dimensions.
+   *  The final loop order from outermost to innermost is
+   *  [x_outer, y_outer, x_inner, y_inner].
+   *
+   * \param x_parent The original x dimension
+   * \param y_parent The original y dimension
+   * \param p_x_outer Outer axis of x dimension
+   * \param p_y_outer Outer axis of y dimension
+   * \param p_x_inner Inner axis of x dimension
+   * \param p_y_inner Inner axis of y dimension
+   * \param x_factor The stride factor on the x axis
+   * \param y_factor The stride factor on the y axis
+   * \return reference to self.
+   */
+  Stage& tile(IterVar x_parent, IterVar y_parent,  // NOLINT(*)
+              IterVar* p_x_outer, IterVar* p_y_outer,
+              IterVar* p_x_inner, IterVar* p_y_inner,
+              Expr x_factor, Expr y_factor);
+};
+
+/*!
+ * \brief Global schedule container
+ *  for operations and all the operations they depend on.
+ *  The schedule for each Operation is called a stage.
+ */
+class Schedule : public NodeRef {
+ public:
+  Schedule() {}
+  explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief construct a schedule for an array of ops (and their dependencies).
+   * \param ops The ops to be scheduled.
+   */
+  explicit Schedule(Array<Operation> ops);
+  /*!
+   * \brief Get the stage corresponding to the op.
+   * \param op The operation.
+   */
+  Stage operator[](const Operation& op);
+  /*!
+   * \brief Shorthand for getting the stage of the tensor's operation.
+   * \param tensor The tensor
+   * \return The stage corresponding to the tensor's op
+   */
+  Stage operator[](const Tensor& tensor) {
+    return this->operator[](tensor->op);
+  }
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const ScheduleNode* operator->() const;
 };
 
 /*!
@@ -135,11 +189,11 @@ class IterVarRelation : public NodeRef {
  *
  *  The relations connect the IterVars in the graph.
  */
-class ScheduleNode : public Node {
+class StageNode : public Node {
  public:
   /*! \brief The operation to be scheduled */
   Operation op;
-  /*! \brief The thread scope level of the schedule */
+  /*! \brief The thread scope level of the stage */
   std::string scope;
   /*! \brief All the iteration variable nodes */
   Array<IterVar> all_iter_vars;
@@ -152,12 +206,10 @@ class ScheduleNode : public Node {
   Array<IterVarRelation> relations;
   /*! \brief The attachment type of the schedule */
   AttachType attach_type{kNone};
-  /*!
-   * \brief The attach point of this schedule.
-   */
-  IterVar attach_parent;
-  /*! \brief the schedules that this schedule depend on */
-  Array<Schedule> children;
+  /*! \brief The attach point of this stage. */
+  IterVar attach_ivar;
+  /*! \brief The stage this node attaches to */
+  Stage attach_stage;
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("scope", &scope);
@@ -166,8 +218,31 @@ class ScheduleNode : public Node {
     v->Visit("leaf_iter_vars", &leaf_iter_vars);
     v->Visit("relations", &relations);
     v->Visit("attach_type", &attach_type);
-    v->Visit("attach_parent", &attach_parent);
-    v->Visit("children", &children);
+    v->Visit("attach_ivar", &attach_ivar);
+    v->Visit("attach_stage", &attach_stage);
+  }
+
+  static constexpr const char* _type_key = "Stage";
+  TVM_DECLARE_NODE_TYPE_INFO(StageNode);
+};
+
+/*! \brief node container for schedule */
+class ScheduleNode : public Node {
+ public:
+  /*! \brief The root operations */
+  Array<Operation> roots;
+  /*!
+   * \brief list of all stages for non-placeholder ops.
+   *  The stages are ordered in post-DFS order of their op.
+   */
+  Array<Stage> stages;
+  /*! \brief map of operation to the stages */
+  Map<Operation, Stage> stage_map;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("roots", &roots);
+    v->Visit("stages", &stages);
+    v->Visit("stage_map", &stage_map);
   }
 
   static constexpr const char* _type_key = "Schedule";
@@ -234,12 +309,16 @@ class FuseNode : public IterVarRelationNode {
 };
 
 // implementations
+inline const StageNode* Stage::operator->() const {
+  return static_cast<const StageNode*>(node_.get());
+}
+inline StageNode* Stage::operator->() {
+  return static_cast<StageNode*>(node_.get());
+}
+
 inline const ScheduleNode* Schedule::operator->() const {
   return static_cast<const ScheduleNode*>(node_.get());
 }
-inline ScheduleNode* Schedule::operator->() {
-  return static_cast<ScheduleNode*>(node_.get());
-}
 
 inline const IterVarRelationNode* IterVarRelation::operator->() const {
   return static_cast<const IterVarRelationNode*>(node_.get());
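The header above is the heart of the change: scheduling state moves from one Schedule per op to a
global Schedule that owns one Stage per Operation. A minimal sketch of how this surfaces through
the Python bindings introduced below — the placeholder shapes, names, and compute expression are
illustrative only, not part of this patch:

    import tvm

    m = tvm.Var('m')
    n = tvm.Var('n')
    A = tvm.placeholder((m, n), name='A')
    T = tvm.compute((m, n), lambda i, j: A[i, j] + 1, name='T')

    s = tvm.Schedule(T.op)      # one Schedule for the op and its dependencies
    stage = s[T]                # operator[] returns the per-op Stage
    stage.set_scope("shared")   # scope is now a mutable per-stage property
    xo, xi = stage.split(T.op.axis[0], factor=8)

Note that set_scope replaces the scope argument of the old per-op Schedule constructor.
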
diff --git a/python/tvm/function.py b/python/tvm/function.py
index 3dde071a6b82..22bfb3555fa1 100644
--- a/python/tvm/function.py
+++ b/python/tvm/function.py
@@ -174,8 +174,17 @@ def max(expr, rdom):
     return x
 
 
-def Schedule(tensor, scope="global"):
-    return _function_internal._Schedule(tensor, scope)
+def Schedule(ops):
+    """Create a schedule for a list of ops.
+
+    Parameters
+    ----------
+    ops : Operation or list of Operation
+        The ops to be scheduled.
+    """
+    if not isinstance(ops, (list, _collections.Array)):
+        ops = [ops]
+    return _function_internal._Schedule(ops)
 
 _init_function_module("tvm")
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 02f73660f0e3..7a5282b1219f 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -2,6 +2,7 @@
 from __future__ import absolute_import as _abs
 from ._ctypes._api import NodeBase, register_node
 from . import _function_internal
+from . import tensor as _tensor
 
 @register_node
 class Split(NodeBase):
@@ -11,10 +12,22 @@ class Split(NodeBase):
 class Fuse(NodeBase):
     pass
 
+
 @register_node
 class Schedule(NodeBase):
+    def __getitem__(self, k):
+        if isinstance(k, _tensor.Tensor):
+            k = k.op
+        if not isinstance(k, _tensor.Operation):
+            raise ValueError("Expect schedule key to be Tensor or Operation")
+        if k not in self.stage_map:
+            raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
+        return self.stage_map[k]
+
+@register_node
+class Stage(NodeBase):
     def split(self, parent, factor=None, outer=None):
-        """Split the schedule either by factor providing outer scope, or both
+        """Split the stage either by a factor, producing outer and inner
+        variables, or by binding a given outer iteration variable.
 
         Parameters
         ----------
@@ -40,11 +53,11 @@ def split(self, parent, factor=None, outer=None):
                 raise ValueError("split by outer must have special thread_tag")
             if outer.dom is None:
                 raise ValueError("split by outer must have specified domain")
-            inner = _function_internal._ScheduleSplitByOuter(self, parent, outer, factor)
+            inner = _function_internal._StageSplitByOuter(self, parent, outer, factor)
         else:
             if factor is None:
                 raise ValueError("either outer or factor need to be provided")
-            outer, inner = _function_internal._ScheduleSplitByFactor(self, parent, factor)
+            outer, inner = _function_internal._StageSplitByFactor(self, parent, factor)
         return outer, inner
 
     def fuse(self, inner, outer):
@@ -63,40 +76,50 @@ def fuse(self, inner, outer):
        fused : IterVar
             The fused variable of iteration.
         """
-        return _function_internal._ScheduleFuse(self, inner, outer)
+        return _function_internal._StageFuse(self, inner, outer)
+
+    def set_scope(self, scope):
+        """Set the thread scope of this stage.
+
+        Parameters
+        ----------
+        scope : str
+            The thread scope of this stage
+        """
+        return _function_internal._StageSetScope(self, scope)
 
     def compute_at(self, parent, scope):
-        """Attach the schedule at parent's scope
+        """Attach the stage at parent's scope.
 
         Parameters
         ----------
-        parent : Schedule
-            The parent schedule
+        parent : Stage
+            The parent stage
 
         scope : IterVar
             The loop scope to be attached to.
         """
-        _function_internal._ScheduleComputeAt(self, parent, scope)
+        _function_internal._StageComputeAt(self, parent, scope)
 
-    def compute_inline(self, parent):
-        """Attach the schedule at parent, and mark it as inline
-
-        Parameters
-        ----------
-        parent : Schedule
-            The parent schedule
-        """
-        _function_internal._ScheduleComputeInline(self, parent)
+    def compute_inline(self):
+        """Mark the stage as inline; its computation is fused into its consumers."""
+        _function_internal._StageComputeInline(self)
 
-    def compute_root(self, parent):
-        """Attach the schedule at parent, and mark it as root
-
-        Parameters
-        ----------
-        parent : Schedule
-            The parent schedule
-        """
-        _function_internal._ScheduleComputeInline(self, parent)
+    def compute_root(self):
+        """Mark the stage as computed at root scope."""
+        _function_internal._StageComputeRoot(self)
 
     def reorder(self, *args):
         """Reorder the iteration variables in the specified order.
@@ -106,9 +129,9 @@ def reorder(self, *args):
         args : list of IterVar
             The new iteration order
         """
-        _function_internal._ScheduleReorder(self, args)
+        _function_internal._StageReorder(self, args)
 
     def tile(self, x_parent, y_parent, x_factor, y_factor):
-        x_outer, y_outer, x_inner, y_inner = _function_internal._ScheduleTile(
+        x_outer, y_outer, x_inner, y_inner = _function_internal._StageTile(
             self, x_parent, y_parent, x_factor, y_factor)
         return x_outer, y_outer, x_inner, y_inner
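With the parent argument gone, compute_inline and compute_root are pure per-stage markers; nothing
is pushed onto a children list anymore. (Note compute_root must dispatch to _StageComputeRoot — the
registered global shown in the next file — not to _StageComputeInline.) A small sketch with
illustrative ops:

    import tvm

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    A1 = tvm.compute((m,), lambda i: A[i], name='A1')
    A2 = tvm.compute((m,), lambda i: A1[i] + 3, name='A2')

    s = tvm.Schedule(A2.op)
    s[A1].compute_inline()   # fuse A1's computation into A2
    # ...or instead keep A1 as a standalone stage at root scope:
    # s[A1].compute_root()
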
diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc
index 46075f1140d4..3f2b4e2a0abd 100644
--- a/src/c_api/c_api_lang.cc
+++ b/src/c_api/c_api_lang.cc
@@ -176,7 +176,7 @@ TVM_REGISTER_API(_ComputeOp)
 TVM_REGISTER_API(_OpGetOutput)
 .set_body([](const ArgStack& args, RetValue *ret) {
     *ret = args.at(0).operator Operation().output(
-        args.at(1).operator size_t());
+        args.at(1).operator int64_t());
   });
 
 
@@ -185,64 +185,69 @@ TVM_REGISTER_API(_IterVar)
     *ret = IterVar(args.at(0), args.at(1), args.at(2));
   });
 
-
 TVM_REGISTER_API(_Schedule)
-.set_body([](const ArgStack& args, RetValue *ret) {
-    *ret = Schedule(args.at(0), args.at(1));
+.set_body([](const ArgStack& args, RetValue *ret) {
+    *ret = Schedule(args.at(0).operator Array<Operation>());
+  });
+
+TVM_REGISTER_API(_StageSetScope)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    args.at(0).operator Stage()
+        .set_scope(args.at(1));
   });
 
-TVM_REGISTER_API(_ScheduleSplitByFactor)
+TVM_REGISTER_API(_StageSplitByFactor)
 .set_body([](const ArgStack& args, RetValue *ret) {
     IterVar outer, inner;
-    args.at(0).operator Schedule()
+    args.at(0).operator Stage()
         .split(args.at(1), &outer, &inner, args.at(2));
     *ret = Array<IterVar>({outer, inner});
   });
 
-TVM_REGISTER_API(_ScheduleSplitByOuter)
+TVM_REGISTER_API(_StageSplitByOuter)
 .set_body([](const ArgStack& args, RetValue *ret) {
     IterVar inner;
-    args.at(0).operator Schedule()
+    args.at(0).operator Stage()
         .split(args.at(1), args.at(2), &inner, args.at(3));
     *ret = inner;
   });
 
-TVM_REGISTER_API(_ScheduleFuse)
+TVM_REGISTER_API(_StageFuse)
 .set_body([](const ArgStack& args, RetValue *ret) {
     IterVar fused;
-    args.at(0).operator Schedule()
-        .split(args.at(1), args.at(2), &fused);
+    args.at(0).operator Stage()
+        .fuse(args.at(1), args.at(2), &fused);
     *ret = fused;
   });
 
-TVM_REGISTER_API(_ScheduleComputeAt)
+TVM_REGISTER_API(_StageComputeAt)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    args.at(0).operator Schedule()
+    args.at(0).operator Stage()
         .compute_at(args.at(1), args.at(2));
   });
 
-TVM_REGISTER_API(_ScheduleComputeInline)
+TVM_REGISTER_API(_StageComputeInline)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    args.at(0).operator Schedule()
-        .compute_inline(args.at(1));
+    args.at(0).operator Stage()
+        .compute_inline();
   });
 
-TVM_REGISTER_API(_ScheduleComputeRoot)
+TVM_REGISTER_API(_StageComputeRoot)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    args.at(0).operator Schedule()
-        .compute_root(args.at(1));
+    args.at(0).operator Stage()
+        .compute_root();
   });
 
-TVM_REGISTER_API(_ScheduleReorder)
+TVM_REGISTER_API(_StageReorder)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    args.at(0).operator Schedule()
+    args.at(0).operator Stage()
         .reorder(args.at(1));
   });
 
-TVM_REGISTER_API(_ScheduleTile)
+TVM_REGISTER_API(_StageTile)
 .set_body([](const ArgStack& args, RetValue *ret) {
     IterVar x_outer, y_outer, x_inner, y_inner;
-    args.at(0).operator Schedule()
+    args.at(0).operator Stage()
         .tile(args.at(1), args.at(2),
               &x_outer, &y_outer, &x_inner, &y_inner,
               args.at(3), args.at(4));
    *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
  });
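Each Python Stage method is a thin shim over one of these registered globals, so the two sides must
be renamed in lockstep. A sketched equivalence, assuming tvm._function_internal is importable the
same way the wrappers above import it:

    import tvm
    from tvm import _function_internal

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    T = tvm.compute((m,), lambda i: A[i], name='T')
    s = tvm.Schedule(T.op)

    # s[T].split(T.op.axis[0], factor=8) is sugar for the registered global:
    xo, xi = _function_internal._StageSplitByFactor(s[T], T.op.axis[0], 8)
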
diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc
index ae3b96f27be5..3cdb8f171a7f 100644
--- a/src/pass/schedule_ops.cc
+++ b/src/pass/schedule_ops.cc
@@ -22,7 +22,7 @@ namespace {
  * \param p_state The message passing state
  *     IterVar->The assignment.
  */
-void PassUpOffset(const Schedule& s,
+void PassUpOffset(const Stage& s,
                   const Map<IterVar, Range>& dom_map,
                   std::unordered_map<IterVar, Expr>* p_state) {
   auto& state = *p_state;
@@ -130,7 +130,7 @@ Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
  *  The flattened Stmt are ordered from outermost to innermost.
  */
 std::vector<std::vector<Stmt> > MakeLoopNest(
-    const Schedule& sch,
+    const Stage& sch,
     const Map<IterVar, Range>& dom_map) {
   // optional, use let to define some CSE in dom_map.
   auto leaf_iter_vars = sch->leaf_iter_vars;
@@ -244,7 +244,7 @@ Stmt MakeRealize(const ComputeOpNode* op,
       bounds, make_const(Bool(1), true), body);
 }
 
-Stmt MakePipeline(const Schedule& sch,
+Stmt MakePipeline(const Stage& sch,
                   const Map<IterVar, Range>& dom_map,
                   Stmt consumer) {
   std::vector<Tensor> tensors;
@@ -280,7 +280,7 @@ Stmt MakePipeline(const Schedule& sch,
 // inject the operator's realization on the stmt.
 class InjectRealize : public IRMutator {
  public:
-  InjectRealize(Schedule schedule, Map<IterVar, Range> dom_map)
+  InjectRealize(Stage schedule, Map<IterVar, Range> dom_map)
       : schedule(schedule), dom_map(dom_map) {}
 
   Stmt Mutate(Stmt stmt) final {
@@ -289,7 +289,7 @@ class InjectRealize : public IRMutator {
     const AttrStmt* op = stmt.as<AttrStmt>();
     if (op != nullptr &&
         op->type_key == "scope") {
-      if (op->node == schedule->attach_parent) {
+      if (op->node == schedule->attach_ivar) {
         CHECK(!found_attach);
         found_attach = true;
         stmt = AttrStmt::make(
@@ -301,41 +301,13 @@ class InjectRealize : public IRMutator {
     return stmt;
   }
   // the operations to be carried
-  Schedule schedule;
+  Stage schedule;
   // domain map
   Map<IterVar, Range> dom_map;
   // whether attach point is found
   bool found_attach{false};
 };
 
-
-void GetOpToScheduleMap(
-    Schedule s,
-    std::unordered_map<Operation, Schedule>* ret) {
-  CHECK(!ret->count(s->op))
-      << "Duplicated schedule for op";
-  (*ret)[s->op] = s;
-  for (Schedule c : s->children) {
-    GetOpToScheduleMap(c, ret);
-  }
-}
-
-// order schedule by DFS calling order of ops
-std::vector<Schedule> OrderSchedule(Schedule s) {
-  auto g = schedule::CreateReadGraph(s->op);
-  auto post_order = schedule::PostDFSOrder(s->op, g);
-  std::unordered_map<Operation, Schedule> op2sch;
-  GetOpToScheduleMap(s, &op2sch);
-  std::vector<Schedule> sorder;
-
-  // reverse iteration.
-  for (size_t i = post_order.size(); i != 0; --i) {
-    sorder.push_back(op2sch.at(post_order[i - 1]));
-  }
-  return sorder;
-}
-
 Stmt InjectInline(const Operation op, Stmt body) {
   CHECK(body.defined());
   const ComputeOpNode* compute = op.as<ComputeOpNode>();
@@ -351,11 +323,11 @@ Stmt InjectInline(const Operation op, Stmt body) {
 }  // namespace
 
 Stmt ScheduleOps(
-    Schedule s, Map<IterVar, Range> dom_map) {
-  std::vector<Schedule> svec = OrderSchedule(s);
+    Schedule sch, Map<IterVar, Range> dom_map) {
   Stmt body = Stmt();
-
-  for (Schedule s : svec) {
+  // reverse of the post-DFS order.
+  for (size_t i = sch->stages.size(); i != 0; --i) {
+    Stage s = sch->stages[i - 1];
     if (s->attach_type == kInline) {
       body = InjectInline(s->op, body);
     } else if (s->attach_type == kRoot || s->attach_type == kNone) {
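Because ScheduleNode::stages is already stored in post-DFS order, ScheduleOps can simply walk it in
reverse — consumers first, with producers injected at their attach points — instead of re-deriving
an order from the old children tree. The Python-level lowering pipeline is unchanged in shape from
the updated tests further below (the ops here are illustrative):

    import tvm

    m = tvm.Var('m')
    l = tvm.Var('l')
    A = tvm.placeholder((m, l), name='A')
    A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')

    s = tvm.Schedule(A1.op)
    bounds = tvm.schedule.InferBound(s)        # Map of IterVar -> Range
    stmt = tvm.ir_pass.ScheduleOps(s, bounds)  # lowered statement
    print(stmt)
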
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 86ed36fb91f9..9ac720305590 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -17,10 +17,10 @@ inline Expr DivCeil(Expr a, Expr b) {
   return (a + b - 1) / b;
 }
 
-// Downward message passing algorithm on schedule s,
+// Downward message passing algorithm on stage s,
 // pass the range state down from the root to the leaves
-// after this pass, every IterVar in the schedule hyper graph will have a range(domain)
-void PassDown(const Schedule& s,
+// after this pass, every IterVar in the stage hypergraph will have a range (domain)
+void PassDown(const Stage& s,
               std::unordered_map<IterVar, Range>* p_state) {
   auto& state = *p_state;
   // forward iteration on relations
@@ -63,7 +63,7 @@ void PassDown(const Schedule& s,
 // pass the integer set on each leaf loop up to the root
 // dom_map is the result of PassDown, it records the domain of each IterVar.
 // dom_map can be used to get cached result in reverse construction.
-void PassUp(const ScheduleNode* s,
+void PassUp(const Stage& s,
             const std::unordered_map<IterVar, Range>& dom_map,
             std::unordered_map<IterVar, IntSet>* p_state) {
   auto& state = *p_state;
@@ -180,13 +180,11 @@ bool ScopeRelax(const IterVar& iv, const std::string& scope) {
   return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag);
 }
 
-void InferBound(
-    const ScheduleNode* parent,
-    const Schedule& sch,
-    std::unordered_map<IterVar, Range>* rmap) {
-  if (sch->attach_type == kInline) return;
-  if (sch->attach_type == kRoot || sch->attach_type == kNone) {
-    auto root_iter_vars = sch->op->root_iter_vars();
+void InferBound(const Stage& stage,
+                std::unordered_map<IterVar, Range>* rmap) {
+  if (stage->attach_type == kInline) return;
+  if (stage->attach_type == kRoot || stage->attach_type == kNone) {
+    auto root_iter_vars = stage->op->root_iter_vars();
     for (auto iv : root_iter_vars) {
       CHECK(iv->dom.defined());
       CHECK(!rmap->count(iv));
@@ -194,22 +192,23 @@ void InferBound(
     }
   }
   // get the range of all child iter vars.
-  PassDown(sch, rmap);
+  PassDown(stage, rmap);
 
-  if (sch->attach_type == kScope) {
-    CHECK(parent != nullptr);
-    auto g = CreateReadGraph(parent->op);
-    auto post_order = PostDFSOrder(parent->op, g);
+  if (stage->attach_type == kScope) {
+    Stage parent = stage->attach_stage;
+    CHECK(parent.defined());
+    auto g = CreateReadGraph({parent->op});
+    auto post_order = PostDFSOrder({parent->op}, g);
     std::unordered_map<IterVar, IntSet> up_state;
     bool fix_value = true;
     for (auto iv : parent->leaf_iter_vars) {
-      if (fix_value && !ScopeRelax(iv, sch->scope)) {
+      if (fix_value && !ScopeRelax(iv, stage->scope)) {
         up_state[iv] = IntSet::make_point(iv->var);
       } else {
         up_state[iv] = IntSet::make_range(rmap->at(iv));
       }
-      if (sch->attach_parent == iv) {
+      if (stage->attach_ivar == iv) {
         fix_value = false;
       }
     }
@@ -221,24 +220,22 @@ void InferBound(
       bp_state[iv] = {up_state.at(iv)};
     }
     auto result = BoundProp(post_order, &bp_state);
-    for (auto iv : sch->op->root_iter_vars()) {
+    for (auto iv : stage->op->root_iter_vars()) {
       CHECK(result.count(iv));
       CHECK(!rmap->count(iv));
       (*rmap)[iv] = result.at(iv).GetCoverRange();
     }
   }
-  // also call infer bound on children
-  for (Schedule child : sch->children) {
-    InferBound(sch.operator->(), child, rmap);
-  }
 }
 
 Map<IterVar, Range> InferBound(Schedule sch) {
   std::unordered_map<IterVar, Range> ret;
-  CHECK(sch->attach_type != kInline && sch->attach_type != kScope)
-      << "the Schedule is not a root Schedule";
-  InferBound(nullptr, sch, &ret);
+  // reverse post-DFS order, from the outermost stage to the innermost
+  for (size_t i = sch->stages.size(); i != 0; --i) {
+    Stage stage = sch->stages[i - 1];
+    InferBound(stage, &ret);
+  }
   return Map<IterVar, Range>(ret.begin(), ret.end());
 }
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index ade0e433d20f..5b13c1569078 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -14,10 +14,15 @@ namespace schedule {
 
 // construct a read graph that gives readers of each operation
 // that the root depend on
-ReadGraph CreateReadGraph(const Operation& root) {
+ReadGraph CreateReadGraph(const Array<Operation>& roots) {
   ReadGraph rmap;
-  std::vector<Operation> stack{root};
-  std::unordered_set<const Node*> visited{root.get()};
+  std::vector<Operation> stack;
+  std::unordered_set<const Node*> visited;
+  // initialize the roots
+  for (Operation op : roots) {
+    stack.push_back(op);
+    visited.insert(op.get());
+  }
 
   while (!stack.empty()) {
     Operation op = stack.back();
@@ -51,20 +56,22 @@ void PostDFSOrder(const Operation& op,
                   const ReadGraph& g,
                   std::unordered_set<Operation>* visited,
                   Array<Operation>* post_order) {
+  if (op.as<PlaceholderOpNode>() || visited->count(op)) return;
   visited->insert(op);
   for (const auto& t : g.at(op)) {
-    if (!t->op.as<PlaceholderOpNode>() && !visited->count(t->op)) {
-      PostDFSOrder(t->op, g, visited, post_order);
-    }
+    PostDFSOrder(t->op, g, visited, post_order);
   }
   post_order->push_back(op);
 }
 
 Array<Operation> PostDFSOrder(
-    const Operation& root, const ReadGraph& g) {
+    const Array<Operation>& roots,
+    const ReadGraph& g) {
   std::unordered_set<Operation> visited;
   Array<Operation> post_order;
-  PostDFSOrder(root, g, &visited, &post_order);
+  for (Operation op : roots) {
+    PostDFSOrder(op, g, &visited, &post_order);
+  }
   return post_order;
 }
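Accepting an Array of roots lets a single read graph and a single post-DFS order cover every
operation a schedule touches, visiting shared producers exactly once; placeholders are skipped by
the traversal. A sketch with an illustrative diamond dependency (names and expressions assumed for
the example):

    import tvm

    m = tvm.Var('m')
    A = tvm.placeholder((m,), name='A')
    B1 = tvm.compute((m,), lambda i: A[i] + 1, name='B1')
    C1 = tvm.compute((m,), lambda i: B1[i] * 2, name='C1')
    C2 = tvm.compute((m,), lambda i: B1[i] + 3, name='C2')

    g = tvm.schedule.CreateReadGraph([C1.op, C2.op])      # multiple roots
    order = tvm.schedule.PostDFSOrder([C1.op, C2.op], g)
    # B1 is emitted once, before both of its consumers:
    assert order[0] == B1.op
    assert order[1] == C1.op
    assert order[2] == C2.op
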
diff --git a/src/schedule/graph.h b/src/schedule/graph.h
index 53e84ca99a34..5a40c8e4ce0f 100644
--- a/src/schedule/graph.h
+++ b/src/schedule/graph.h
@@ -24,14 +24,14 @@ using ReadGraph = Map<Operation, Array<Tensor> >;
  *  Tensors that it directly depends on.
  *
  *  The result map contains Operations needed to finish the root Operations.
- * \param root The root operation.
+ * \param roots The root operations.
  * \return The result map.
  */
-ReadGraph CreateReadGraph(const Operation& root);
+ReadGraph CreateReadGraph(const Array<Operation>& roots);
 
 /*!
 * \brief Get a post-DFS ordering of operations in the graph.
- * \param root The root of the graph.
+ * \param roots The roots of the graph.
  * \param g The read graph.
  * \return vector order of Operations in post-DFS order.
  *
 *  and can be used when topological order is needed.
  */
 Array<Operation> PostDFSOrder(
-    const Operation& root, const ReadGraph& g);
+    const Array<Operation>& roots, const ReadGraph& g);
 
 }  // namespace schedule
 }  // namespace tvm
diff --git a/src/lang/schedule.cc b/src/schedule/schedule_lang.cc
similarity index 74%
rename from src/lang/schedule.cc
rename to src/schedule/schedule_lang.cc
index b5d4429eb06e..22b3d10d8b9b 100644
--- a/src/lang/schedule.cc
+++ b/src/schedule/schedule_lang.cc
@@ -3,6 +3,7 @@
  * \file schedule.cc
  */
 #include <tvm/schedule.h>
+#include "./graph.h"
 
 namespace tvm {
 
@@ -31,7 +32,7 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
   return 0;
 }
 
-void Split(ScheduleNode* self, IterVar parent,
+void Split(StageNode* self, IterVar parent,
            IterVar outer, IterVar inner, Expr factor) {
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
@@ -49,19 +50,30 @@ void Split(ScheduleNode* self, IterVar parent,
 
 }  // namespace
 
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<StageNode>([](const StageNode *op, IRPrinter *p) {
+    p->stream << "stage("
+              << op->op
+              << ")";
+});
+
-Schedule::Schedule(Operation op, std::string scope) {
-  auto n = std::make_shared<ScheduleNode>();
+Stage::Stage(Operation op) {
+  auto n = std::make_shared<StageNode>();
   n->op = op;
-  n->scope = scope;
   n->all_iter_vars = op->root_iter_vars();
   n->leaf_iter_vars = op->root_iter_vars();
   node_ = n;
 }
 
+Stage& Stage::set_scope(std::string scope) {  // NOLINT(*)
+  (*this)->scope = scope;
+  return *this;
+}
+
-Schedule& Schedule::compute_at(Schedule parent, IterVar scope) {  // NOLINT(*)
-  CHECK_EQ((*this)->attach_type, kNone);
+Stage& Stage::compute_at(Stage parent, IterVar scope) {  // NOLINT(*)
   (*this)->attach_type = kScope;
-  (*this)->attach_parent = scope;
+  (*this)->attach_ivar = scope;
+  (*this)->attach_stage = parent;
   bool found = false;
   for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
     if (scope == parent->leaf_iter_vars[i]) {
@@ -70,25 +82,20 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) {  // NOLINT(*)
     }
   }
   CHECK(found)
      << "Cannot compute at an iteration variable that is not part of parent leaf vars";
-  parent->children.push_back(*this);
   return *this;
 }
 
-Schedule& Schedule::compute_inline(Schedule parent) {  // NOLINT(*)
-  CHECK_EQ((*this)->attach_type, kNone);
+Stage& Stage::compute_inline() {  // NOLINT(*)
   (*this)->attach_type = kInline;
-  parent->children.push_back(*this);
   return *this;
 }
 
-Schedule& Schedule::compute_root(Schedule parent) {  // NOLINT(*)
-  CHECK_EQ((*this)->attach_type, kNone);
+Stage& Stage::compute_root() {  // NOLINT(*)
   (*this)->attach_type = kRoot;
-  parent->children.push_back(*this);
   return *this;
 }
 
-Schedule& Schedule::split(
+Stage& Stage::split(
     IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
   // placeholder for the split results.
   IterVar outer(Range(), parent->var->name_hint + ".outer");
@@ -99,7 +106,7 @@ Schedule& Schedule::split(
   return *this;
 }
 
-Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
+Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) {  // NOLINT(*)
   // placeholder for the split results.
   IterVar inner(Range(), parent->var->name_hint + ".inner");
   *p_inner = inner;
@@ -108,9 +115,9 @@ Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr
   return *this;
 }
 
-Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  // NOLINT(*)
+Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  // NOLINT(*)
   IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
-  ScheduleNode* self = operator->();
+  StageNode* self = operator->();
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
 
@@ -128,8 +135,8 @@ Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  //
   return *this;
 }
 
-Schedule& Schedule::reorder(const Array<IterVar>& order) {  // NOLINT(*)
-  ScheduleNode* self = operator->();
+Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
+  StageNode* self = operator->();
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
   std::vector<size_t> pos;
@@ -148,16 +155,34 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) {  // NOLINT(*)
   return *this;
 }
 
-Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
-                         IterVar* p_x_outer, IterVar* p_y_outer,
-                         IterVar* p_x_inner, IterVar* p_y_inner,
-                         Expr x_factor, Expr y_factor) {  // NOLINT(*)
+Stage& Stage::tile(IterVar x_parent, IterVar y_parent,
+                   IterVar* p_x_outer, IterVar* p_y_outer,
+                   IterVar* p_x_inner, IterVar* p_y_inner,
+                   Expr x_factor, Expr y_factor) {  // NOLINT(*)
   split(x_parent, p_x_outer, p_x_inner, x_factor);
   split(y_parent, p_y_outer, p_y_inner, y_factor);
   reorder(Array<IterVar>({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner}));
   return *this;
 }
 
+Schedule::Schedule(Array<Operation> ops) {
+  auto n = std::make_shared<ScheduleNode>();
+  n->roots = ops;
+  auto g = schedule::CreateReadGraph(n->roots);
+  Array<Operation> post_order = schedule::PostDFSOrder(n->roots, g);
+  for (Operation op : post_order) {
+    Stage stage(op);
+    n->stages.push_back(stage);
+    n->stage_map.Set(op, stage);
+  }
+  node_ = std::move(n);
+}
+
+Stage Schedule::operator[](const Operation& op) {
+  return (*this)->stage_map.at(op);
+}
+
 IterVarRelation SplitNode::make(
     IterVar parent, IterVar outer,
     IterVar inner, Expr factor) {
@@ -178,7 +203,7 @@ IterVarRelation FuseNode::make(
   return IterVarRelation(n);
 }
 
-TVM_REGISTER_NODE_TYPE(ScheduleNode);
+TVM_REGISTER_NODE_TYPE(StageNode);
 TVM_REGISTER_NODE_TYPE(SplitNode);
 TVM_REGISTER_NODE_TYPE(FuseNode);
 
diff --git a/tests/python/test_lang_schedule.py b/tests/python/test_lang_schedule.py
index 221c998d3648..7e3c2f3ce64d 100644
--- a/tests/python/test_lang_schedule.py
+++ b/tests/python/test_lang_schedule.py
@@ -7,41 +7,37 @@ def test_schedule_create():
     A = tvm.placeholder((m, l), name='A')
     B = tvm.placeholder((n, l), name='B')
     AA = tvm.compute((m, l), lambda i, j: A[i, j])
-    T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
-
-    sch_T = tvm.Schedule(T.op, scope="shared")
-    sch_A = tvm.Schedule(AA.op, scope="global")
-
-    xo, xi = sch_T.split(T.op.axis[0], factor=10)
-    xi1, xi2 = sch_T.split(xi, factor=2)
-
-    sch_A.compute_at(sch_T, xi1)
-    xo, xi = sch_A.split(AA.op.axis[0],
factor=10) - - sch_T.reorder(xi2, xi1) - assert T.op.axis[1] in sch_T.leaf_iter_vars + T = tvm.compute((m, n, l), lambda i, j, k: AA(i, k) * B(j, k)) + s = tvm.Schedule(T.op) + s[AA].set_scope("shared") + xo, xi = s[T].split(T.op.axis[0], factor=10) + xi1, xi2 = s[T].split(xi, factor=2) + s[AA].compute_at(s[T], xi1) + xo, xi = s[AA].split(AA.op.axis[0], factor=10) + s[T].reorder(xi2, xi1) + assert T.op.axis[1] in s[T].leaf_iter_vars def test_reorder(): m = tvm.Var('m') A = tvm.placeholder((m,), name='A') T = tvm.compute(m, lambda i: A[i+1]) - sch_T = tvm.Schedule(T.op, scope="shared") - xo, xi = sch_T.split(T.op.axis[0], factor=10) - xi1, xi2 = sch_T.split(xi, factor=2) + s = tvm.Schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=10) + xi1, xi2 = s[T].split(xi, factor=2) order = (xi2, xi1, xo) - assert tuple(sch_T.leaf_iter_vars) != order - sch_T.reorder(*order) - assert tuple(sch_T.leaf_iter_vars) == order + assert tuple(s[T].leaf_iter_vars) != order + s[T].reorder(*order) + assert tuple(s[T].leaf_iter_vars) == order def test_split(): m = tvm.Var('m') A = tvm.placeholder((m,), name='A') T = tvm.compute((m,), lambda i: A[i]) - sT = tvm.Schedule(T.op) - xo, xi = sT.split(T.op.axis[0], factor=10) - assert tuple(sT.leaf_iter_vars) == (xo, xi) + s = tvm.Schedule(T.op) + xo, xi = s[T].split(T.op.axis[0], factor=10) + assert tuple(s[T].leaf_iter_vars) == (xo, xi) def test_tile(): @@ -50,9 +46,9 @@ def test_tile(): A = tvm.placeholder((m, n), name='A') T = tvm.compute((m, n), lambda i, j: A[i, j]) - sch_T = tvm.Schedule(T.op, scope="shared") - xo, yo, xi, yi = sch_T.tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5) - assert tuple(sch_T.leaf_iter_vars) == (xo, yo, xi, yi) + s = tvm.Schedule(T.op) + xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5) + assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi) if __name__ == "__main__": test_schedule_create() diff --git a/tests/python/test_pass_schedule_ops.py b/tests/python/test_pass_schedule_ops.py index 3a74bd7f53bf..e634d0773b0c 100644 --- a/tests/python/test_pass_schedule_ops.py +++ b/tests/python/test_pass_schedule_ops.py @@ -6,10 +6,12 @@ def test_schedule0(): l = tvm.Var('l') A = tvm.placeholder((m, l), name='A') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') - sA1 = tvm.Schedule(A1.op) - bounds = tvm.schedule.InferBound(sA1) + + s = tvm.Schedule(A1.op) + + bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) - stmt = tvm.ir_pass.ScheduleOps(sA1, bounds) + stmt = tvm.ir_pass.ScheduleOps(s, bounds) print(stmt) def test_schedule1(): @@ -17,11 +19,12 @@ def test_schedule1(): l = tvm.Var('l') A = tvm.placeholder((m, l), name='A') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') - sA1 = tvm.Schedule(A1.op) - xo, xi = sA1.split(A1.op.axis[0], 8) - bounds = tvm.schedule.InferBound(sA1) + + s = tvm.Schedule(A1.op) + xo, xi = s[A1].split(A1.op.axis[0], 8) + bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) - stmt = tvm.ir_pass.ScheduleOps(sA1, bounds) + stmt = tvm.ir_pass.ScheduleOps(s, bounds) print(stmt) def test_schedule2(): @@ -30,13 +33,13 @@ def test_schedule2(): A = tvm.placeholder((m, l), name='A') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') - sA1 = tvm.Schedule(A1.op) - sA2 = tvm.Schedule(A2.op) - xo, xi = sA2.split(A2.op.axis[0], 8) - sA1.compute_at(sA2, xo) - bounds = tvm.schedule.InferBound(sA2) + + s = tvm.Schedule(A2.op) + xo, xi = 
s[A2].split(A2.op.axis[0], 8) + s[A1].compute_at(s[A2], xo) + bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) - stmt = tvm.ir_pass.ScheduleOps(sA2, bounds) + stmt = tvm.ir_pass.ScheduleOps(s, bounds) print(stmt) diff --git a/tests/python/test_schedule_bound_inference.py b/tests/python/test_schedule_bound_inference.py index 7970d99080b7..9e8c70cac66b 100644 --- a/tests/python/test_schedule_bound_inference.py +++ b/tests/python/test_schedule_bound_inference.py @@ -6,11 +6,11 @@ def test_bound1(): A = tvm.placeholder((m, l), name='A') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') - sA1 = tvm.Schedule(A1.op) - sA2 = tvm.Schedule(A2.op) - xo, xi = sA2.split(A2.op.axis[0], 8) - sA1.compute_at(sA2, xo) - bounds = tvm.schedule.InferBound(sA2) + + s = tvm.Schedule([A2.op]) + xo, xi = s[A2].split(s[A2].op.axis[0], 8) + s[A1].compute_at(s[A2], xo) + bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) assert(bounds[A1.op.axis[0]].extent.value == 8) @@ -20,11 +20,10 @@ def test_bound2(): A = tvm.placeholder((m, l), name='A') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') - sA1 = tvm.Schedule(A1.op) - sA2 = tvm.Schedule(A2.op) - xo, yo, xi, yi = sA2.tile(A2.op.axis[0], A2.op.axis[1], 8, 8) - sA1.compute_at(sA2, yo) - bounds = tvm.schedule.InferBound(sA2) + s = tvm.Schedule(A2.op) + xo, yo, xi, yi = s[A2].tile(A2.op.axis[0], A2.op.axis[1], 8, 8) + s[A1].compute_at(s[A2], yo) + bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) assert(bounds[A1.op.axis[0]].extent.value == 8) assert(bounds[A1.op.axis[1]].extent.value == 8) @@ -35,16 +34,18 @@ def test_bound3(): A = tvm.placeholder((m, l), name='A') A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') - sA1 = tvm.Schedule(A1.op, scope="shared") - sA2 = tvm.Schedule(A2.op) + + s = tvm.Schedule(A2.op) + + s[A1].set_scope("shared") thread_x = tvm.IterVar((0, 16), thread_tag="threadIdx.x") - xo, xi = sA2.split(A2.op.axis[0], 32) - xi0, xi1 = sA2.split(xi, outer=thread_x) - yo, yi = sA2.split(A2.op.axis[1], 16) - sA2.reorder(xo, xi0, yo, xi1, yi) - sA1.compute_at(sA2, yo) + xo, xi = s[A2].split(A2.op.axis[0], 32) + xi0, xi1 = s[A2].split(xi, outer=thread_x) + yo, yi = s[A2].split(A2.op.axis[1], 16) + s[A2].reorder(xo, xi0, yo, xi1, yi) + s[A1].compute_at(s[A2], yo) - bounds = tvm.schedule.InferBound(sA2) + bounds = tvm.schedule.InferBound(s) assert isinstance(bounds, tvm.collections.Map) assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[1]].extent.value==16) @@ -56,10 +57,12 @@ def test_create_read_graph(): A = tvm.placeholder((m, l), name='A') A1 = tvm.compute((m, l), lambda i, j: A[i, j]) A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3) - g = tvm.schedule.CreateReadGraph(A2.op) + + g = tvm.schedule.CreateReadGraph([A2.op]) + assert g[A2.op][0] == A1 assert g[A1.op][0] == A - post_order = tvm.schedule.PostDFSOrder(A2.op, g) + post_order = tvm.schedule.PostDFSOrder([A2.op], g) assert(post_order[0] == A1.op) assert(post_order[1] == A2.op) diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index 039a1abc0033..d682e2b0fb94 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -1,17 +1,17 @@ #!/bin/bash - -if [ ${TASK} == "lint" ]; then - make lint || exit -1 - echo "Check documentations of c++ 
code..." - make doc 2>log.txt - (cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt - echo "---------Error Log----------" - cat logclean.txt - echo "----------------------------" - (cat logclean.txt|grep warning) && exit -1 - (cat logclean.txt|grep error) && exit -1 - exit 0 +if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then + if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then + make lint || exit -1 + echo "Check documentations of c++ code..." + make doc 2>log.txt + (cat log.txt| grep -v ENABLE_PREPROCESSING |grep -v "unsupported tag") > logclean.txt + echo "---------Error Log----------" + cat logclean.txt + echo "----------------------------" + (cat logclean.txt|grep warning) && exit -1 + (cat logclean.txt|grep error) && exit -1 + fi fi @@ -22,19 +22,16 @@ if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then fi fi -if [ ${TASK} == "cpp_test" ]; then +if [ ${TASK} == "cpp_test" ] || [ ${TASK} == "all_test" ]; then make -f dmlc-core/scripts/packages.mk gtest make test || exit -1 for test in tests/cpp/*_test; do ./$test || exit -1 done - exit 0 fi -# run two test one for cython, one for ctypes -if [ ${TASK} == "python_test" ]; then - make clean - make -j all || exit -1 +if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then + make all || exit -1 if [ ${TRAVIS_OS_NAME} == "osx" ]; then python -m nose tests/python/ || exit -1 python3 -m nose tests/python/ || exit -1 @@ -42,5 +39,4 @@ if [ ${TASK} == "python_test" ]; then nosetests tests/python/ || exit -1 nosetests3 tests/python/ || exit -1 fi - exit 0 fi diff --git a/tests/travis/setup.sh b/tests/travis/setup.sh index 2e6545a50d3f..ff5d48d953af 100755 --- a/tests/travis/setup.sh +++ b/tests/travis/setup.sh @@ -1,15 +1,16 @@ #!/bin/bash - if [ ${TRAVIS_OS_NAME} == "osx" ]; then - brew update - brew install python3 - if [ ${TASK} == "python_test" ]; then + if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then + brew update + brew install python3 python -m pip install --user nose python3 -m pip install --user nose fi fi -if [ ${TASK} == "lint" ]; then - pip install --user cpplint 'pylint==1.4.4' 'astroid==1.3.6' +if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then + if [ ! 
${TRAVIS_OS_NAME} == "osx" ]; then + pip install --user cpplint 'pylint==1.4.4' 'astroid==1.3.6' + fi fi From e978a3b9cfc31b2088041248ef20e25ac0dd5ec7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 10 Jan 2017 23:35:26 -0800 Subject: [PATCH 2/3] add numpy as dep --- tests/travis/setup.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/travis/setup.sh b/tests/travis/setup.sh index ff5d48d953af..1635fbc572a0 100755 --- a/tests/travis/setup.sh +++ b/tests/travis/setup.sh @@ -4,8 +4,8 @@ if [ ${TRAVIS_OS_NAME} == "osx" ]; then if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then brew update brew install python3 - python -m pip install --user nose - python3 -m pip install --user nose + python -m pip install --user nose numpy + python3 -m pip install --user nose numpy fi fi From 572899bdad7f0f9bd70b08fc1da7643c720dd0b7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 11 Jan 2017 09:36:15 -0800 Subject: [PATCH 3/3] add numpy installation, temporary disable osx --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 2ae5bb0793d1..31d6e49f3dd1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ language: cpp os: - linux - - osx + # - osx env: # code analysis @@ -33,6 +33,7 @@ addons: - g++-4.8 - python-numpy - python-nose + - python3-numpy - python3-dev - python3-nose - graphviz