diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index a4120d20288f..96db7d762cae 100644
--- a/src/relay/analysis/util.cc
+++ b/src/relay/analysis/util.cc
@@ -394,6 +394,7 @@ bool IsAllPositiveConstant(const Expr& expr) {
   static const auto& reshape_op = Op::Get("reshape");
   static const auto& transpose_op = Op::Get("transpose");
   static const auto& squeeze_op = Op::Get("squeeze");
+  static const auto& repeat_op = Op::Get("repeat");
 
   // peel through a few common transform ops.
   if (const auto* constant = expr.as<ConstantNode>()) {
@@ -419,7 +420,7 @@ bool IsAllPositiveConstant(const Expr& expr) {
   } else if (const auto* op = expr.as<CallNode>()) {
     // tail recursion.
     if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op ||
-        op->op == squeeze_op) {
+        op->op == squeeze_op || op->op == repeat_op) {
       return IsAllPositiveConstant(op->args[0]);
     } else {
       return false;
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 7518380de3b1..7334308e4a16 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -427,16 +427,6 @@ bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr& exp
   return false;
 }
 
-Expr InferType(const Expr& expr) {
-  auto mod = IRModule::FromExpr(expr);
-  mod = transform::InferType()(mod);
-  if (expr.as<FunctionNode>()) {
-    return mod->Lookup("main");
-  } else {
-    return mod->Lookup("main").as<FunctionNode>()->body;
-  }
-}
-
 Expr InferTypeWithModule(const Expr& expr, const IRModule& m) {
   IRModule mod(m->functions, m->type_definitions, m->Imports());
   GlobalVarSupply global_var_supply = GlobalVarSupply(mod);
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index d03939e09ea8..aa4ef03c95a4 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -102,6 +102,24 @@ namespace relay {
     LOG(FATAL) << "unknown data type " << type; \
   }
 
+/*!
+ * \brief Try to do the type inference over expr:
+ *
+ * Do the infer_type over each node in expr
+ *
+ * \param expr The IR expression
+ * \return infered expr if succeed.
+ */
+inline Expr InferType(const Expr& expr) {
+  auto mod = IRModule::FromExpr(expr);
+  mod = transform::InferType()(mod);
+  if (expr.as<FunctionNode>()) {
+    return mod->Lookup("main");
+  } else {
+    return mod->Lookup("main").as<FunctionNode>()->body;
+  }
+}
+
 /*!
  * \brief Try to match lhs and rhs via broadcasting rule, such that:
  *
@@ -121,6 +139,17 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const TensorTyp
   size_t base = tlhs->shape.size() - trhs->shape.size();
   size_t j = 0;
 
+  // handle case trhs is simple constant
+  if (trhs->shape.size() == 0 && rhs_value != nullptr && lhs_axes.size() > 0) {
+    *rhs_value = MakeExpandDims(*rhs_value, 0, lhs_axes.size());
+    for (size_t i = 0; i < lhs_axes.size(); i++) {
+      int repeat_value =
+          tlhs->shape[static_cast<size_t>(lhs_axes[j]->value)].as<IntImmNode>()->value;
+      *rhs_value = MakeRepeat(*rhs_value, repeat_value, i);
+    }
+    return true;
+  }
+
   ObjectPtr<SqueezeAttrs> squeeze_attrs;
   if (rhs_value != nullptr) {
     squeeze_attrs = make_object<SqueezeAttrs>();
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 6cae728b304f..923a18f7bc93 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -847,6 +847,37 @@ class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite {
   DFPattern c2_;
 };
 
+/*! \brief Simplifying x+x to x*2 */
+class SimplifyAdd : public DFPatternRewrite {
+ public:
+  SimplifyAdd() {
+    x_ = IsWildcard();
+    y_ = IsWildcard();
+    pattern_ = IsOp("add")({x_, y_});
+  }
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    Type pre_type = pre->checked_type_;
+    auto dtype = pre_type.as<TensorTypeNode>()->dtype;
+    auto x = node_map[x_][0];
+    auto y = node_map[y_][0];
+    auto data_type = Downcast<TensorType>(x->checked_type());
+
+    if (x == y) {
+      Expr value;
+      value = MakeConstantScalar(dtype, 2);
+      return InferType(Call(Op::Get("multiply"), {x, value}));
+    }
+    return post;
+  }
+
+ private:
+  /*! \brief Pattern input */
+  DFPattern x_;
+  DFPattern y_;
+};
+
 /*! \brief Simplifying x/sqrt to x*sqrt */
 class SimplifyRSqrt : public DFPatternRewrite {
  public:
@@ -925,6 +956,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
   composer.AddRewrite();
   composer.AddRewrite();
   composer.AddRewrite();
+  composer.AddRewrite<SimplifyAdd>();
   composer.AddRewrite();
   composer.AddRewrite();
   composer.AddRewrite();
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index 12fc722d8604..8ffa3ef832e0 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -20,6 +20,12 @@
 from tvm import te
 from tvm import relay
 from tvm.relay import transform
+from tvm.relay.testing import create_workload
+from tvm.relay.build_module import bind_params_by_name
+
+
+def initializer(_, param):
+    param = np.zeros(param.shape)
 
 
 def _get_positive_scale(size):
@@ -636,6 +642,50 @@ def check(shape, in_channels, channels, blocking):
     check((2, 2, 10, 10, 2), 4, 8, (2, 2))
 
 
+def test_fold_bwd_simple_constant():
+    def before(data, weight, out_bias, channels):
+        y = relay.nn.conv2d(
+            data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
+        )
+
+        y = relay.add(y, out_bias)
+        c2 = relay.const(2.0)
+        y = relay.nn.relu(y)
+        y = relay.multiply(y, c2)
+        mod, params = create_workload(y, initializer)
+        mod["main"] = bind_params_by_name(mod["main"], params)
+        return mod
+
+    def expected(data, weight, out_bias, channels):
+        y0 = relay.nn.conv2d(
+            data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
+        )
+        y0 = relay.add(y0, out_bias)
+        y0 = relay.nn.relu(y0)
+        mod, params = create_workload(y0, initializer)
+        mod["main"] = bind_params_by_name(mod["main"], params)
+        return mod
+
+    def check(shape, channels):
+        x = relay.var("data", relay.TensorType(shape, "float32"))
+        weight = relay.var("weight")
+        out_bias = relay.var("in_bias", shape=(channels, 1, 1))
+
+        y0 = before(x, weight, out_bias, channels)
+        remove_last_multiply = tvm.transform.Sequential(
+            [
+                relay.transform.InferType(),
+                relay.transform.FoldScaleAxis(),
+            ]
+        )
+        with tvm.transform.PassContext(opt_level=3):
+            y0 = remove_last_multiply(y0)
+        _expect = expected(x, weight, out_bias, channels)
+        tvm.ir.assert_structural_equal(y0, _expect)
+
+    check((1, 3, 200, 200), 16)
+
+
 def test_fold_bwd_dual_consumer():
     def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
         args = [x, conv_weight, out_bias]
@@ -1211,6 +1261,7 @@ def check(shape, in_channels, channels, blocking):
     test_fold_fwd_relu_fail()
     test_fold_fwd_negative_scale()
     test_fold_fwd_dense()
+    test_fold_bwd_simple_constant()
     test_fold_bwd_simple()
     test_fold_bwd_dual_path()
     test_fold_bwd_dual_consumer()
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 6df07966eb0a..fa9773b8e3d9 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -729,5 +729,20 @@ def expected():
     assert tvm.ir.structural_equal(opt, ref)
 
 
+def test_simplify_add():
+    x = relay.var("x", shape=(1, 3, 100, 100), dtype="float32")
+
+    def before():
+        return relay.add(x, x)
+
+    def expected():
+        s = relay.const(2.0)
+        return relay.multiply(x, s)
+
+    opt = run_opt_pass(before(), transform.SimplifyExpr())
+    ref = run_infer_type(expected())
+    assert tvm.ir.structural_equal(opt, ref)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
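
For reference, the two halves of this change are meant to compose: the new SimplifyExpr rewrite turns add(x, x) into multiply(x, 2), and the pattern_utils/FoldScaleAxis change lets that scalar constant be folded toward a preceding conv2d. Below is a minimal Python sketch of driving both passes together; it is not part of the patch, the variable names, shapes, and dtype are illustrative, and the exact folded form depends on the rest of the pass pipeline.

# Illustrative sketch only: run SimplifyExpr then FoldScaleAxis so that
# add(conv, conv) first becomes multiply(conv, 2) and the scalar 2 can then
# be folded toward the convolution weight.
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=8, padding=(1, 1))
out = relay.add(conv, conv)  # candidate for the new SimplifyAdd rewrite

mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.SimplifyExpr(),   # add(x, x) -> multiply(x, 2)
        relay.transform.FoldScaleAxis(),  # fold the scalar scale into the conv
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod["main"])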