Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Opt] [ir] Optimize offload #2673

Merged
merged 16 commits into from
Aug 21, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 78 additions & 166 deletions taichi/transforms/offload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,10 @@ class IdentifyValuesUsedInOtherOffloads : public BasicStmtVisitor {
if (stmt->is<PtrOffsetStmt>()) {
if (local_to_global.find(stmt->cast<PtrOffsetStmt>()->origin) ==
local_to_global.end()) {
auto alloca_stmt = stmt->cast<PtrOffsetStmt>()->origin;
local_to_global[alloca_stmt] = allocate_global(alloca_stmt->ret_type);
if (stmt->cast<PtrOffsetStmt>()->origin->is<AllocaStmt>()) {
auto alloca_stmt = stmt->cast<PtrOffsetStmt>()->origin;
local_to_global[alloca_stmt] = allocate_global(alloca_stmt->ret_type);
}
}
} else {
// stmt might be AllocaStmt, ExternalTensorShapeAlongAxisStmt
Expand Down Expand Up @@ -421,14 +423,28 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
if (stmt->body)
stmt->body->accept(this);
if (stmt->task_type == OffloadedStmt::TaskType::range_for) {
if (!stmt->const_begin)
if (!stmt->const_begin) {
TI_ASSERT(offloaded_ranges_->begin_stmts.find(stmt) !=
offloaded_ranges_->begin_stmts.end())
TI_ASSERT_INFO(local_to_global_offset.find(
offloaded_ranges_->begin_stmts.find(stmt)->second) !=
local_to_global_offset.end(),
"Begin fails.")
stmt->begin_offset =
local_to_global_offset[offloaded_ranges_->begin_stmts.find(stmt)
->second];
if (!stmt->const_end)
}
if (!stmt->const_end) {
TI_ASSERT(offloaded_ranges_->end_stmts.find(stmt) !=
offloaded_ranges_->end_stmts.end())
TI_ASSERT_INFO(local_to_global_offset.find(
offloaded_ranges_->end_stmts.find(stmt)->second) !=
local_to_global_offset.end(),
"End fails.")
stmt->end_offset =
local_to_global_offset[offloaded_ranges_->end_stmts.find(stmt)
->second];
}
}
}

Expand All @@ -441,166 +457,71 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
local_to_global_vector_type[stmt] = ret_type;
auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset[stmt], ret_type);
stmt_to_offloaded[ptr] = stmt_to_offloaded[stmt];
squarefk marked this conversation as resolved.
Show resolved Hide resolved
if (auto tensor_type = stmt->ret_type->cast<TensorType>()) {
LaneAttribute<TypedConstant> zero(std::vector<TypedConstant>(
1, TypedConstant(tensor_type->get_element_type())));
auto const_zero_stmt = replacement.push_back<ConstStmt>(zero);
stmt_to_offloaded[const_zero_stmt] = stmt_to_offloaded[stmt];
for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
LaneAttribute<TypedConstant> offset(std::vector<TypedConstant>(
1, TypedConstant(i *
data_type_size(tensor_type->get_element_type()))));
auto const_offset_stmt = replacement.push_back<ConstStmt>(offset);
auto ptr_offset_stmt =
replacement.push_back<PtrOffsetStmt>(ptr, const_offset_stmt);
replacement.push_back<GlobalStoreStmt>(ptr_offset_stmt,
const_zero_stmt);
auto global_store_stmt = replacement.push_back<GlobalStoreStmt>(
ptr_offset_stmt, const_zero_stmt);
stmt_to_offloaded[const_offset_stmt] = stmt_to_offloaded[stmt];
stmt_to_offloaded[ptr_offset_stmt] = stmt_to_offloaded[stmt];
stmt_to_offloaded[global_store_stmt] = stmt_to_offloaded[stmt];
}
} else {
LaneAttribute<TypedConstant> zeros(std::vector<TypedConstant>(
stmt->width(), TypedConstant(stmt->ret_type)));
auto const_zeros = replacement.push_back<ConstStmt>(zeros);
replacement.push_back<GlobalStoreStmt>(ptr, const_zeros);
auto global_store_stmt =
replacement.push_back<GlobalStoreStmt>(ptr, const_zeros);
stmt_to_offloaded[global_store_stmt] = stmt_to_offloaded[stmt];
}

stmt->parent->replace_with(stmt, std::move(replacement), false);
// To deal with the same offloaded visit_operand()
stmt_to_offloaded[stmt] = nullptr;
throw IRModified();
}

void visit(PtrOffsetStmt *stmt) override {
auto alloca = stmt->origin;
auto offset = stmt->offset;

if (stmt_to_offloaded[offset] != stmt_to_offloaded[stmt]) {
// Replace local LD/ST with global LD/ST
void visit(LocalLoadStmt *stmt) override {
generic_visit(stmt);
TI_ASSERT(stmt->width() == 1)
auto ptr = stmt->src[0].var;
auto top_level_ptr = ptr;
while (top_level_ptr->is<PtrOffsetStmt>())
squarefk marked this conversation as resolved.
Show resolved Hide resolved
top_level_ptr = top_level_ptr->cast<PtrOffsetStmt>()->origin;
if (top_level_ptr->is<GlobalTemporaryStmt>()) {
VecStatement replacement;
// TODO: offset may not be ConstStmt
auto copy_stmt = offset->as<ConstStmt>()->copy();
auto copy = replacement.push_back(std::move(copy_stmt));
stmt_to_offloaded[copy] = stmt_to_offloaded[stmt];
auto ptr_offset = replacement.push_back<PtrOffsetStmt>(alloca, copy);
auto global_load = replacement.push_back<GlobalLoadStmt>(ptr);
stmt_to_offloaded[global_load] = stmt_to_offloaded[stmt];
stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}

if (!(alloca->is<AllocaStmt>() &&
alloca->cast<AllocaStmt>()->ret_type->is<TensorType>()))
return;
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
return;

VecStatement replacement;
auto ret_type = alloca->cast<AllocaStmt>()->ret_type;

auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset[alloca], ret_type);
replacement.push_back<PtrOffsetStmt>(ptr, stmt->offset);

stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}

// Replace local LD/ST with global LD/ST
void visit(LocalLoadStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
// TensorType Alloca
if (stmt->src[0].var->is<PtrOffsetStmt>() &&
stmt->src[0].var->as<PtrOffsetStmt>()->origin->is<AllocaStmt>()) {
auto alloca =
stmt->src[0].var->as<PtrOffsetStmt>()->origin->as<AllocaStmt>();
if (local_to_global_offset.find(alloca) != local_to_global_offset.end()) {
// Converted to GlobalTemporaryStmt
VecStatement replacement;
auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset[alloca], alloca->ret_type);
// TODO: offset may not be ConstStmt
auto copy_stmt = stmt->src[0]
.var->as<PtrOffsetStmt>()
->offset->as<ConstStmt>()
->copy();
auto copy = replacement.push_back(std::move(copy_stmt));
auto ptr_offset = replacement.push_back<PtrOffsetStmt>(ptr, copy);
replacement.push_back<GlobalLoadStmt>(ptr_offset);

stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}
}
// Scalar Alloca
auto alloca = stmt->src[0].var;
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
return;

VecStatement replacement;
auto ret_type = stmt->ret_type;

auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset[alloca], ret_type);
replacement.push_back<GlobalLoadStmt>(ptr);

stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}

void visit(LocalStoreStmt *stmt) override {
if (visit_operand(stmt, stmt->locate_operand(&stmt->val)))
generic_visit(stmt);
auto ptr = stmt->dest;
auto top_level_ptr = ptr;
while (top_level_ptr->is<PtrOffsetStmt>())
top_level_ptr = top_level_ptr->cast<PtrOffsetStmt>()->origin;
if (top_level_ptr->is<GlobalTemporaryStmt>()) {
VecStatement replacement;
auto global_store =
replacement.push_back<GlobalStoreStmt>(ptr, stmt->val);
stmt_to_offloaded[global_store] = stmt_to_offloaded[stmt];
stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
TI_ASSERT(stmt->width() == 1);
// TensorType Alloca
if (stmt->dest->is<PtrOffsetStmt>() &&
stmt->dest->as<PtrOffsetStmt>()->origin->is<AllocaStmt>()) {
auto alloca = stmt->dest->as<PtrOffsetStmt>()->origin->as<AllocaStmt>();
if (local_to_global_offset.find(alloca) != local_to_global_offset.end()) {
// Converted to GlobalTemporaryStmt
VecStatement replacement;
auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset[alloca], alloca->ret_type);
// TODO: offset may not be ConstStmt
auto copy_stmt =
stmt->dest->as<PtrOffsetStmt>()->offset->as<ConstStmt>()->copy();
auto copy = replacement.push_back(std::move(copy_stmt));
auto ptr_offset = replacement.push_back<PtrOffsetStmt>(ptr, copy);
replacement.push_back<GlobalStoreStmt>(ptr_offset, stmt->val);

stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}
}
// Scalar Alloca
auto alloca = stmt->dest;
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
return;

VecStatement replacement;
auto ret_type = stmt->ret_type;

auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset[alloca], ret_type);
replacement.push_back<GlobalStoreStmt>(ptr, stmt->val);

stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}

void visit(AtomicOpStmt *stmt) override {
if (!stmt->dest->is<AllocaStmt>()) {
generic_visit(stmt);
return;
}
if (visit_operand(stmt, stmt->locate_operand(&stmt->val)))
throw IRModified();
TI_ASSERT(stmt->width() == 1);
auto alloca = stmt->dest;
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
return;

VecStatement replacement;
auto ret_type = stmt->dest->ret_type;

auto ptr = replacement.push_back<GlobalTemporaryStmt>(
local_to_global_offset[alloca], ret_type);
replacement.push_back<AtomicOpStmt>(stmt->op_type, ptr, stmt->val);

stmt->parent->replace_with(stmt, std::move(replacement));
throw IRModified();
}

bool visit_operand(Stmt *stmt, int index) {
Expand All @@ -617,7 +538,6 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
return true;
}
}
if (op->is<GlobalPtrStmt>()) {
squarefk marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -626,40 +546,32 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
return true;
}

if (op->is<PtrOffsetStmt>() &&
op->cast<PtrOffsetStmt>()->origin->is<AllocaStmt>()) {
auto alloca = op->as<PtrOffsetStmt>()->origin->as<AllocaStmt>();
if (local_to_global_offset.find(alloca) == local_to_global_offset.end())
return false;

auto global_temporary_stmt = Stmt::make<GlobalTemporaryStmt>(
local_to_global_offset[alloca], alloca->ret_type);
// TODO: offset may not be ConstStmt
auto copy_stmt = op->as<PtrOffsetStmt>()->offset->as<ConstStmt>()->copy();
auto ptr_offset_stmt = Stmt::make<PtrOffsetStmt>(
global_temporary_stmt.get(), copy_stmt.get());
stmt_to_offloaded[copy_stmt.get()] = stmt_to_offloaded[stmt];
stmt_to_offloaded[ptr_offset_stmt.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, ptr_offset_stmt.get());
stmt->insert_before_me(std::move(copy_stmt));
stmt->insert_before_me(std::move(global_temporary_stmt));
stmt->insert_before_me(std::move(ptr_offset_stmt));
return true;
if (local_to_global_offset.find(op) == local_to_global_offset.end()) {
// For cases like ConstStmt
auto copy = op->clone();
stmt_to_offloaded[copy.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, copy.get());
stmt->insert_before_me(std::move(copy));
squarefk marked this conversation as resolved.
Show resolved Hide resolved
} else {
auto global_temporary = Stmt::make<GlobalTemporaryStmt>(
local_to_global_offset[op], op->ret_type);
stmt_to_offloaded[global_temporary.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, global_temporary.get());
if (op->is<AllocaStmt>() || op->ret_type.is_pointer()) {
// For cases like an Alloca (either TensorType or scalar), which will be
// followed by a LocalLoad. Avoid emitting a repeated load here.
stmt->insert_before_me(std::move(global_temporary));
} else {
// For other cases, like ArgLoadStmt or UnaryOpStmt, which need a load.
auto load = Stmt::make<GlobalLoadStmt>(global_temporary.get());
stmt_to_offloaded[load.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, load.get());
stmt->insert_before_me(std::move(global_temporary));
stmt->insert_before_me(std::move(load));
}
}

if (local_to_global_offset.find(op) == local_to_global_offset.end())
return false;

auto global = Stmt::make<GlobalTemporaryStmt>(local_to_global_offset[op],
op->ret_type);
auto load = Stmt::make<GlobalLoadStmt>(global.get());
stmt_to_offloaded[load.get()] = stmt_to_offloaded[stmt];
stmt->set_operand(index, load.get());
stmt->insert_before_me(std::move(global));
stmt->insert_before_me(std::move(load));
return true;
}

Expand All @@ -675,7 +587,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor {
}

void visit(Stmt *stmt) override {
TI_ASSERT(stmt->width() == 1 || (stmt->ret_type->is<TensorType>()));
TI_ASSERT(stmt->width() == 1)
generic_visit(stmt);
}

Expand Down
18 changes: 18 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,24 @@ def func2():
func2()
assert v[1][9] == 9

@ti.kernel
def func3():
tmp = ti.Vector([1, 2, 3], dt=ti.i32)
for i in range(3):
tmp[i] = i * i
assert tmp[0] == 0
assert tmp[1] == 1
assert tmp[2] == 4

func3()

@ti.kernel
def func4(k: ti.i32):
tmp = ti.Vector([k, k * 2, k * 3])

with pytest.raises(Exception):
func4(10)


@ti.test(arch=ti.cpu)
def test_matrix_constant_index():
Expand Down