Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LANG/PASS] Virtual thread support #38

Merged
merged 1 commit into from
Feb 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,30 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};

/*! \brief namespace of possible attributes in AttrStmt.type_key */
namespace attr {
/*!
 * \brief Mark scope of iteration variable, used by Schedule.
 */
constexpr const char* scope = "scope";
/*!
 * \brief Mark launching extent of thread, used by device API.
 */
constexpr const char* thread_extent = "thread_extent";
/*!
 * \brief Mark launching of a virtual thread.
 */
constexpr const char* virtual_thread = "virtual_thread";
/*!
 * \brief Mark storage scope of buffers.
 */
constexpr const char* storage_scope = "storage_scope";
/*!
 * \brief Mark storage scope of realizations.
 */
constexpr const char* realize_scope = "realize_scope";
}  // namespace attr

/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of the intrinsics is to enab
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir_mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class IRMutator {
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Stmt Mutate_(const Block* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
Expand Down
20 changes: 20 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ Stmt Inline(Stmt stmt,
 * \param stmt The stmt to be transformed.
* \param extern_buffer Map specifies external
* buffer assignment of input and outputs.
* \return Transformed stmt.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);
Expand All @@ -108,15 +109,34 @@ Stmt StorageFlatten(Stmt stmt,
* \brief unroll the constant loops
 * \param stmt The statement to be unrolled.
* \param max_auto_step The maximum step to stop performing automatic unrolling.
* \return Transformed stmt.
*/
Stmt UnrollLoop(Stmt stmt, int max_auto_step);

/*!
* \brief vectorize the constant loops
 * \param stmt The statement to be vectorized.
* \return Transformed stmt.
*/
Stmt VectorizeLoop(Stmt stmt);

/*!
* \brief Inject virtual thread loops into stmt.
 * \param stmt The statement to be transformed.
* \return Transformed stmt.
*/
Stmt InjectVirtualThread(Stmt stmt);

/*!
 * \brief Lift storage allocation to the relevant outermost location
*
* Only do this after vectorization and virtual thread injection completes.
*
 * \param stmt The stmt to be transformed
* \return Transformed stmt.
*/
Stmt LiftAllocate(Stmt stmt);

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def build(sch,
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.InjectVirtualThread(stmt)
stmt = ir_pass.LiftAllocate(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
Expand Down
2 changes: 2 additions & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(LiftAllocate);
REGISTER_PASS1(InjectVirtualThread);

} // namespace ir
} // namespace tvm
3 changes: 2 additions & 1 deletion src/arithmetic/canonical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ class Canonical::Internal : public IRMutator {
}
// AttrStmt
Stmt Mutate_(const AttrStmt* op, const Stmt& s) {
if (op->type_key == "thread_extent") {
if (op->type_key == attr::thread_extent ||
op->type_key == attr::virtual_thread) {
++level_counter_;
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
Expand Down
4 changes: 2 additions & 2 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
}

void CodeGenC::PrintStmt(const AttrStmt* op) {
if (op->type_key == "scope") {
if (op->type_key == ir::attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
if (!var_idmap_.count(iv->var.get())) {
Expand All @@ -756,7 +756,7 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
stream << ";\n";
}
}
} else if (op->type_key == "storage_scope") {
} else if (op->type_key == ir::attr::storage_scope) {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
Expand Down
12 changes: 12 additions & 0 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <string>
#include "./codegen_cuda.h"
#include "./codegen_stack_vm.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/cuda/cuda_common.h"
#include "../runtime/cuda/cuda_module.h"

Expand All @@ -22,6 +23,17 @@ std::string CodeGenCUDA::Compile(
return CodeGenC::Compile(f, output_ssa);
}

void CodeGenCUDA::PrintStmt(const ir::For* op) {
  // Loops are expected to be normalized so that they start at zero.
  CHECK(is_zero(op->min));
  // Emit "#pragma unroll" ahead of small constant-extent loops so the
  // CUDA compiler fully unrolls them while the generated source stays compact.
  int const_extent = 0;
  const bool emit_unroll_hint =
      arith::GetConstInt(op->extent, &const_extent) &&
      const_extent <= max_auto_unroll_;
  if (emit_unroll_hint) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  // Delegate the actual loop printing to the C code generator.
  CodeGenC::PrintStmt(op);
}

void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
Expand Down
6 changes: 6 additions & 0 deletions src/codegen/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class CodeGenCUDA : public CodeGenC {
bool output_ssa);

// override behavior
void PrintStmt(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
Expand All @@ -37,6 +38,11 @@ class CodeGenCUDA : public CodeGenC {
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;

private:
// magic number to add pragma unroll to it.
// used to generate code that is compact but still unrolls.
int max_auto_unroll_{8};
};

} // namespace codegen
Expand Down
Loading