diff --git a/taichi/analysis/verify.cpp b/taichi/analysis/verify.cpp index 07350aa5889b4..b643f58a8a968 100644 --- a/taichi/analysis/verify.cpp +++ b/taichi/analysis/verify.cpp @@ -44,10 +44,12 @@ class IRVerifier : public BasicStmtVisitor { TI_ASSERT_INFO( found, "IR broken: stmt {} cannot have operand {}." - " Consider adding `ti.core.toggle_advanced_optimization(False)`." - " If that fixes the problem, please report this bug by opening an" - " issue at https://github.com/taichi-dev/taichi to help us improve." - " Thanks!", + " If you are using autodiff, please check" + " https://taichi.readthedocs.io/en/stable/" + "differentiable_programming.html#kernel-simplicity-rule." + " If it doesn't help, please report this bug by opening an issue at" + " https://github.com/taichi-dev/taichi to help us improve." + " Thanks in advance!", stmt->id, op->id); } visible_stmts.back().insert(stmt); diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index fc1742537a64e..2cff5466c6f24 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -497,6 +497,7 @@ class KernelGen : public IRVisitor { const auto rhs_name = bin->rhs->short_name(); const auto bin_name = bin->short_name(); if (bin->op_type == BinaryOpType::floordiv) { + TI_WARN("floordiv called! 
It should be taken care by demote_operations"); if (is_integral(bin->lhs->element_type()) && is_integral(bin->rhs->element_type())) { emit( diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index d7366594abc8c..fe3fcda5d44fb 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -21,6 +21,7 @@ bool die(IRNode *root); bool simplify(IRNode *root, Kernel *kernel = nullptr); bool cfg_optimization(IRNode *root, bool after_lower_access); bool alg_simp(IRNode *root); +bool demote_operations(IRNode *root); bool binary_op_simplify(IRNode *root); bool whole_kernel_cse(IRNode *root); void variable_optimization(IRNode *root, bool after_lower_access); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 6ea925785bf35..13830af85044b 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -50,6 +50,10 @@ void compile_to_offloads(IRNode *ir, if (ir->get_kernel()->is_evaluator) { TI_ASSERT(!grad); + + irpass::demote_operations(ir); + print("Operations Demoted"); + irpass::offload(ir); print("Offloaded"); irpass::analysis::verify(ir); @@ -164,6 +168,9 @@ void offload_to_executable(IRNode *ir, print("Atomics demoted"); irpass::analysis::verify(ir); + irpass::demote_operations(ir); + print("Operations demoted"); + irpass::full_simplify(ir, lower_global_access); print("Simplified IV"); diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp new file mode 100644 index 0000000000000..1c4a87644fcbb --- /dev/null +++ b/taichi/transforms/demote_operations.cpp @@ -0,0 +1,102 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" +#include "taichi/program/program.h" + +TLANG_NAMESPACE_BEGIN + +// Demote operations into simpler pieces that are easier for backends to handle +class DemoteOperations : public BasicStmtVisitor { + public: + using BasicStmtVisitor::visit; + DelayedIRModifier modifier; + 
DemoteOperations() : BasicStmtVisitor() { + } + + void visit(BinaryOpStmt *stmt) override {  // rewrites floordiv into div-based statements; other ops are left untouched + auto lhs = stmt->lhs; + auto rhs = stmt->rhs; + if (stmt->op_type == BinaryOpType::floordiv) { + if (is_integral(rhs->element_type()) && + is_integral(lhs->element_type())) { + // @ti.func + // def ifloordiv(a, b): + // r = ti.raw_div(a, b) + // if (a < 0) != (b < 0) and a and b * r != a: + // r = r - 1 + // return r + auto ret = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs); + auto zero = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(0)); + auto lhs_ltz = + Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, lhs, zero.get()); + auto rhs_ltz = + Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, rhs, zero.get()); + auto rhs_mul_ret = + Stmt::make<BinaryOpStmt>(BinaryOpType::mul, rhs, ret.get()); + auto cond1 = Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, + lhs_ltz.get(), rhs_ltz.get()); + auto cond2 = + Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, lhs, zero.get()); + auto cond3 = Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, + rhs_mul_ret.get(), lhs); + auto cond12 = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, + cond1.get(), cond2.get()); + auto cond = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, + cond12.get(), cond3.get()); + auto real_ret = + Stmt::make<BinaryOpStmt>(BinaryOpType::add, ret.get(), cond.get());  // NOTE(review): the pseudo-code subtracts 1; adding cond matches it only if a true comparison evaluates to -1 -- confirm + + stmt->replace_with(real_ret.get()); + modifier.insert_before(stmt, std::move(ret)); + modifier.insert_before(stmt, std::move(zero)); + modifier.insert_before(stmt, std::move(lhs_ltz)); + modifier.insert_before(stmt, std::move(rhs_ltz)); + modifier.insert_before(stmt, std::move(rhs_mul_ret)); + modifier.insert_before(stmt, std::move(cond1)); + modifier.insert_before(stmt, std::move(cond2)); + modifier.insert_before(stmt, std::move(cond3)); + modifier.insert_before(stmt, std::move(cond12)); + modifier.insert_before(stmt, std::move(cond)); + modifier.insert_before(stmt, std::move(real_ret)); + modifier.erase(stmt); + + } else {  // at least one operand is not integral + // @ti.func + // def ffloordiv(a, b): + // r = ti.raw_div(a, b) + // return ti.floor(r) + auto div = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs); + auto floor = 
Stmt::make<UnaryOpStmt>(UnaryOpType::floor, div.get()); + stmt->replace_with(floor.get()); + modifier.insert_before(stmt, std::move(div)); + modifier.insert_before(stmt, std::move(floor)); + modifier.erase(stmt); + } + } + } + + static bool run(IRNode *node) { + DemoteOperations demoter; + bool modified = false; + while (true) {  // repeat the traversal until modify_ir() reports no change + node->accept(&demoter); + if (demoter.modifier.modify_ir()) + modified = true; + else + break; + } + return modified; + } +}; + +namespace irpass { + +bool demote_operations(IRNode *root) { + TI_AUTO_PROF; + return DemoteOperations::run(root); +} + +} // namespace irpass + +TLANG_NAMESPACE_END diff --git a/tests/python/test_div.py b/tests/python/test_div.py index 6fa848eee9d1b..3f684610062fd 100644 --- a/tests/python/test_div.py +++ b/tests/python/test_div.py @@ -56,7 +56,7 @@ def test_true_div(): _test_true_div(ti.f32, -3, ti.i32, 2, ti.i32, -1) -@ti.all_archs +@ti.test() def test_div_default_ip(): ti.get_runtime().set_default_ip(ti.i64) z = ti.field(ti.f32, shape=()) @@ -70,7 +70,7 @@ def func(): assert z[None] == 100000 -@ti.all_archs +@ti.test() def test_floor_div_pythonic(): z = ti.field(ti.i32, shape=())