Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] [opt] Unify ifloordiv implementation over backends in demote_operations #1771

Merged
merged 13 commits into from
Aug 28, 2020
Merged
14 changes: 6 additions & 8 deletions taichi/analysis/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@ class IRVerifier : public BasicStmtVisitor {
break;
}
}
TI_ASSERT_INFO(
found,
"IR broken: stmt {} cannot have operand {}."
" Consider adding `ti.core.toggle_advanced_optimization(False)`."
" If that fixes the problem, please report this bug by opening an"
" issue at https://github.com/taichi-dev/taichi to help us improve."
" Thanks!",
stmt->id, op->id);
TI_ASSERT_INFO(found,
"IR broken: stmt {} cannot have operand {}."
" Please report this bug by opening an issue at"
archibate marked this conversation as resolved.
Show resolved Hide resolved
" https://github.com/taichi-dev/taichi to help us improve."
" Thanks in advance!",
stmt->id, op->id);
}
visible_stmts.back().insert(stmt);
}
Expand Down
1 change: 1 addition & 0 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ class KernelGen : public IRVisitor {
if (bin->op_type == BinaryOpType::floordiv) {
if (is_integral(bin->lhs->element_type()) &&
is_integral(bin->rhs->element_type())) {
TI_WARN("Integer floordiv called! It should be taken care by alg_simp");
archibate marked this conversation as resolved.
Show resolved Hide resolved
emit(
"{} {} = {}(sign({}) * {} >= 0 ? abs({}) / abs({}) : sign({}) * "
"(abs({}) + abs({}) - 1) / {});",
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ bool die(IRNode *root);
bool simplify(IRNode *root, Kernel *kernel = nullptr);
bool cfg_optimization(IRNode *root, bool after_lower_access);
bool alg_simp(IRNode *root);
bool demote_operations(IRNode *root);
bool binary_op_simplify(IRNode *root);
bool whole_kernel_cse(IRNode *root);
void variable_optimization(IRNode *root, bool after_lower_access);
Expand Down
3 changes: 3 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ void compile_to_offloads(IRNode *ir,
print("Typechecked");
irpass::analysis::verify(ir);

irpass::demote_operations(ir);
print("Operations Demoted");

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this to just before Simplified IV as it can impede other optimizations.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, we only demote_operations in compile_to_executable, not compile_to_offloads?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We may perform some optimizations after compile_to_offloads in async mode.

if (ir->get_kernel()->is_evaluator) {
TI_ASSERT(!grad);
irpass::offload(ir);
Expand Down
102 changes: 102 additions & 0 deletions taichi/transforms/demote_operations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#include "taichi/ir/ir.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/visitors.h"
#include "taichi/program/program.h"

TLANG_NAMESPACE_BEGIN

// Demote Operations into pieces for backends to deal easier
class DemoteOperations : public BasicStmtVisitor {
public:
using BasicStmtVisitor::visit;
DelayedIRModifier modifier;

DemoteOperations() : BasicStmtVisitor() {
}

void visit(BinaryOpStmt *stmt) override {
auto lhs = stmt->lhs;
auto rhs = stmt->rhs;
if (stmt->op_type == BinaryOpType::floordiv) {
if (is_integral(rhs->element_type()) &&
is_integral(lhs->element_type())) {
// @ti.func
// def ifloordiv(a, b):
// r = ti.raw_div(a, b)
// if (a < 0) != (b < 0) and a and b * r != a:
// r = r - 1
// return r
Comment on lines +23 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest not involving b * r != a since I don't think we can optimize it out easily.

What we need is:

a >= 0, b > 0: a / b
a >= 0, b < 0: (a - b + 1) / b
a <= 0, b > 0: (a - b + 1) / b
a <= 0, b < 0: a / b

So ((a < 0) == (b < 0) ? a : (a - b + 1)) / b should be fine.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't work for a = 2, b = -3:

import taichi as ti


@ti.func
def ifloordiv(a, b):
    return (a if (a < 0) == (b < 0) else (a - b + 1)) / b


@ti.kernel
def func(a: int, b: int) -> int:
    return ifloordiv(a, b)


a, b = 2, -3
print(func(a, b), a // b)

This is non-trivial, ff2 change it iapr on your own.

Copy link
Contributor

@xumingkuan xumingkuan Aug 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it should be

a >= 0, b > 0: a / b
a >= 0, b < 0: (a - b - 1) / b
a <= 0, b > 0: (a - b + 1) / b
a <= 0, b < 0: a / b

......
BTW when do we encounter floordiv at the frontend? Only when ti.floordiv(a, b)?

Oh, 2 // -3 in Python is -1...

auto ret = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs);
auto zero = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(0));
auto lhs_ltz =
Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, lhs, zero.get());
auto rhs_ltz =
Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, rhs, zero.get());
auto rhs_mul_ret =
Stmt::make<BinaryOpStmt>(BinaryOpType::mul, rhs, ret.get());
auto cond1 = Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne,
lhs_ltz.get(), rhs_ltz.get());
auto cond2 =
Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, lhs, zero.get());
auto cond3 = Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne,
rhs_mul_ret.get(), lhs);
auto cond12 = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and,
cond1.get(), cond2.get());
auto cond = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and,
cond12.get(), cond3.get());
auto real_ret =
Stmt::make<BinaryOpStmt>(BinaryOpType::add, ret.get(), cond.get());

stmt->replace_with(real_ret.get());
modifier.insert_before(stmt, std::move(ret));
modifier.insert_before(stmt, std::move(zero));
modifier.insert_before(stmt, std::move(lhs_ltz));
modifier.insert_before(stmt, std::move(rhs_ltz));
modifier.insert_before(stmt, std::move(rhs_mul_ret));
modifier.insert_before(stmt, std::move(cond1));
modifier.insert_before(stmt, std::move(cond2));
modifier.insert_before(stmt, std::move(cond3));
modifier.insert_before(stmt, std::move(cond12));
modifier.insert_before(stmt, std::move(cond));
modifier.insert_before(stmt, std::move(real_ret));
modifier.erase(stmt);

} else {
// @ti.func
// def ffloordiv(a, b):
// r = ti.raw_div(a, b)
// return ti.floor(r)
auto div = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs);
auto floor = Stmt::make<UnaryOpStmt>(UnaryOpType::floor, div.get());
stmt->replace_with(floor.get());
modifier.insert_before(stmt, std::move(div));
modifier.insert_before(stmt, std::move(floor));
modifier.erase(stmt);
}
}
}

static bool run(IRNode *node) {
DemoteOperations demoter;
bool modified = false;
while (true) {
node->accept(&demoter);
if (demoter.modifier.modify_ir())
modified = true;
else
break;
}
return modified;
}
};

namespace irpass {

bool demote_operations(IRNode *root) {
TI_AUTO_PROF;
return DemoteOperations::run(root);
}

} // namespace irpass

TLANG_NAMESPACE_END
4 changes: 2 additions & 2 deletions tests/python/test_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_true_div():
_test_true_div(ti.f32, -3, ti.i32, 2, ti.i32, -1)


@ti.all_archs
@ti.test()
def test_div_default_ip():
ti.get_runtime().set_default_ip(ti.i64)
z = ti.field(ti.f32, shape=())
Expand All @@ -70,7 +70,7 @@ def func():
assert z[None] == 100000


@ti.all_archs
@ti.test()
def test_floor_div_pythonic():
z = ti.field(ti.i32, shape=())

Expand Down