diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index 10ffe95f653d..e45e25a265d0 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -15,22 +15,22 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_pass_Simplify) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); - if (dynamic_cast(args.at(0).sptr.get())) { - *ret = Simplify(args.at(0).operator Expr()); - } else { + if (dynamic_cast(args.at(0).sptr.get())) { *ret = Simplify(args.at(0).operator Stmt()); + } else { + *ret = Simplify(args.at(0).operator Expr()); } }); TVM_REGISTER_API(_pass_Equal) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); - CHECK(args.at(1).type_id == kNodeHandle); - if (dynamic_cast(args.at(0).sptr.get())) { - *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); - } else { + if (dynamic_cast(args.at(0).sptr.get())) { + CHECK(args.at(1).type_id == kNodeHandle); *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); + } else { + Expr a = args.at(0).operator Expr(); + Expr b = args.at(1).operator Expr(); + *ret = Equal(a, b); } }); diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py index ebffc58805f3..b9e8d501e68a 100644 --- a/tests/python/test_pass_basic.py +++ b/tests/python/test_pass_basic.py @@ -8,6 +8,9 @@ def test_simplify(): assert(tvm.ir_pass.Equal(e2, x * 8)) e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) + let = tvm.make.Let(x, 1, x + 3) + e4 = tvm.ir_pass.Simplify(let) + assert(tvm.ir_pass.Equal(e4, 4)) def test_verify_ssa(): @@ -20,8 +23,9 @@ def test_verify_ssa(): def test_convert_ssa(): x = tvm.Var('x') y = tvm.Var() - let = tvm.make.Let(x, 1, x + 1) - z = tvm.make.Evaluate(let + let) + let1 = tvm.make.Let(x, 1, x + 1) + let2 = tvm.make.Let(x, 1, x + y) + z = tvm.make.Evaluate(let1 + let2) assert(not tvm.ir_pass.VerifySSA(z)) z_ssa = tvm.ir_pass.ConvertSSA(z) assert(tvm.ir_pass.VerifySSA(z_ssa))