Skip to content

Commit

Permalink
add temporary variable, fix TI_ASSERT_INFO and abstract SquashPtrOffset
Browse files Browse the repository at this point in the history
  • Loading branch information
squarefk committed Aug 18, 2021
1 parent 4994e01 commit b7a89ab
Showing 1 changed file with 43 additions and 20 deletions.
63 changes: 43 additions & 20 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@ TLANG_NAMESPACE_BEGIN
namespace irpass {
namespace {

class SquashPtrOffset : public IRVisitor {
public:
SquashPtrOffset() {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}
void visit(Stmt *stmt) override {
top_level_ptr = stmt;
}
void visit(PtrOffsetStmt *stmt) override {
stmt->origin->accept(this);
}
static Stmt *run(Stmt *root) {
SquashPtrOffset v;
root->accept(&v);
return v.top_level_ptr;
}

private:
Stmt *top_level_ptr = nullptr;
};

// Offloaded local variables to its offset in the global tmps memory.
using StmtToOffsetMap = std::unordered_map<const Stmt *, std::size_t>;

Expand Down Expand Up @@ -305,13 +327,11 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
// Directly insert copies of ConstStmts later
if (stmt->is<ConstStmt>())
return;
auto top_level_ptr = SquashPtrOffset::run(stmt);
// We don't support storing a pointer for now.
if (stmt->is<GlobalPtrStmt>())
if (top_level_ptr->is<GlobalPtrStmt>())
return;
// Not yet allocated
auto top_level_ptr = stmt;
while (top_level_ptr->is<PtrOffsetStmt>())
top_level_ptr = top_level_ptr->cast<PtrOffsetStmt>()->origin;
if (local_to_global.find(top_level_ptr) == local_to_global.end()) {
local_to_global[top_level_ptr] = allocate_global(top_level_ptr->ret_type);
}
Expand Down Expand Up @@ -445,12 +465,13 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
local_to_global_vector_type[stmt] = ret_type;
auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset[stmt], ret_type);
stmt_to_offloaded[ptr] = stmt_to_offloaded[stmt];
auto offloaded = stmt_to_offloaded[stmt];
stmt_to_offloaded[ptr] = offloaded;
if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
LaneAttribute<TypedConstant> zero(std::vector<TypedConstant>(
1, TypedConstant(tensor_type->get_element_type())));
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
stmt_to_offloaded[const_zero_stmt] = stmt_to_offloaded[stmt];
stmt_to_offloaded[const_zero_stmt] = offloaded;
for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
LaneAttribute<TypedConstant> offset(std::vector<TypedConstant>(
1, TypedConstant(i *
Expand All @@ -460,17 +481,17 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
replacement.push_back<PtrOffsetStmt>(ptr, const_offset_stmt);
auto global_store_stmt = replacement.push_back<GlobalStoreStmt>(
ptr_offset_stmt, const_zero_stmt);
stmt_to_offloaded[const_offset_stmt] = stmt_to_offloaded[stmt];
stmt_to_offloaded[ptr_offset_stmt] = stmt_to_offloaded[stmt];
stmt_to_offloaded[global_store_stmt] = stmt_to_offloaded[stmt];
stmt_to_offloaded[const_offset_stmt] = offloaded;
stmt_to_offloaded[ptr_offset_stmt] = offloaded;
stmt_to_offloaded[global_store_stmt] = offloaded;
}
} else {
LaneAttribute<TypedConstant> zeros(std::vector<TypedConstant>(
stmt->width(), TypedConstant(stmt->ret_type)));
auto const_zeros = replacement.push_back<ConstStmt>(zeros);
auto global_store_stmt =
replacement.push_back<GlobalStoreStmt>(ptr, const_zeros);
stmt_to_offloaded[global_store_stmt] = stmt_to_offloaded[stmt];
stmt_to_offloaded[global_store_stmt] = offloaded;
}

stmt->parent->replace_with(stmt, std::move(replacement), false);
Expand All @@ -484,9 +505,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
generic_visit(stmt);
TI_ASSERT(stmt->width() == 1)
auto ptr = stmt->src[0].var;
auto top_level_ptr = ptr;
while (top_level_ptr->is<PtrOffsetStmt>())
top_level_ptr = top_level_ptr->cast<PtrOffsetStmt>()->origin;
auto top_level_ptr = SquashPtrOffset::run(ptr);
if (top_level_ptr->is<GlobalTemporaryStmt>()) {
VecStatement replacement;
auto global_load = replacement.push_back<GlobalLoadStmt>(ptr);
Expand All @@ -499,9 +518,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
void visit(LocalStoreStmt *stmt) override {
generic_visit(stmt);
auto ptr = stmt->dest;
auto top_level_ptr = ptr;
while (top_level_ptr->is<PtrOffsetStmt>())
top_level_ptr = top_level_ptr->cast<PtrOffsetStmt>()->origin;
auto top_level_ptr = SquashPtrOffset::run(ptr);
if (top_level_ptr->is<GlobalTemporaryStmt>()) {
VecStatement replacement;
auto global_store =
Expand All @@ -520,25 +537,31 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
return false;
if (stmt_to_offloaded[stmt] == stmt_to_offloaded[op]) // same OffloadedStmt
return false;

auto offloaded = stmt_to_offloaded[stmt];

if (op->is<GlobalPtrStmt>()) {
auto copy = op->clone();
copy->as<GlobalPtrStmt>()->activate = false;
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt_to_offloaded[copy.get()] = offloaded;
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
return true;
}

if (local_to_global_offset.find(op) == local_to_global_offset.end()) {
TI_ASSERT_INFO(op->is<ConstStmt>() || op->is<PtrOffsetStmt>() ||
op->is<GlobalTemporaryStmt>(),
"{} is not allowed here.", op->type());
// For cases like ConstStmt
auto copy = op->clone();
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt_to_offloaded[copy.get()] = offloaded;
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
} else {
auto global_temporary = Stmt::make<GlobalTemporaryStmt>(
local_to_global_offset[op], op->ret_type);
stmt_to_offloaded[global_temporary.get()] = stmt_to_offloaded[stmt];
stmt_to_offloaded[global_temporary.get()] = offloaded;
stmt->set_operand(index, global_temporary.get());
if (op->is<AllocaStmt>() || op->ret_type.is_pointer()) {
// For cases like Alloca both TensorType and Scalar which will be
Expand All @@ -547,7 +570,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
} else {
// For other cases like ArgLoadStmt UnaryOpStmt which needs to load.
auto load = Stmt::make<GlobalLoadStmt>(global_temporary.get());
stmt_to_offloaded[load.get()] = stmt_to_offloaded[stmt];
stmt_to_offloaded[load.get()] = offloaded;
stmt->set_operand(index, load.get());
stmt->insert_before_me(std::move(global_temporary));
stmt->insert_before_me(std::move(load));
Expand Down

0 comments on commit b7a89ab

Please sign in to comment.