[Relay] Enhancement for fold_scale_axis and simplify_expr (#13275)
Convert add(%1, %1) to multiply(%1, 2f); enhance fold_scale_axis so it can fold multiply(%1, 2f) into the preceding conv2d.

Signed-off-by: Lei Wen <[email protected]>
Co-authored-by: Lei Wen <[email protected]>
leiwen83 and wenlei03 authored Nov 3, 2022
1 parent 90ed632 commit b1a099b
Showing 6 changed files with 129 additions and 11 deletions.
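In short, this commit teaches SimplifyExpr to rewrite add(%1, %1) into multiply(%1, 2f), and lets FoldScaleAxis fold such a scalar multiply back into a preceding conv2d. A minimal end-to-end sketch of that pipeline (illustrative only; the shapes, variable names, and pass ordering here are assumptions, not taken from this PR):

import tvm
from tvm import relay

# conv2d followed by y + y; SimplifyExpr should turn the add into
# multiply(y, 2f), and FoldScaleAxis should then fold the scale into the conv.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
y = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=16, padding=(1, 1))
y = relay.add(y, y)

mod = tvm.IRModule.from_expr(relay.Function([data, weight], y))
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.SimplifyExpr(),   # add(y, y) -> multiply(y, 2f)
        relay.transform.FoldScaleAxis(),  # fold the scalar scale into the conv
    ]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)
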
3 changes: 2 additions & 1 deletion src/relay/analysis/util.cc
@@ -394,6 +394,7 @@ bool IsAllPositiveConstant(const Expr& expr) {
static const auto& reshape_op = Op::Get("reshape");
static const auto& transpose_op = Op::Get("transpose");
static const auto& squeeze_op = Op::Get("squeeze");
static const auto& repeat_op = Op::Get("repeat");

// peel through a few common transform ops.
if (const auto* constant = expr.as<ConstantNode>()) {
@@ -419,7 +420,7 @@ bool IsAllPositiveConstant(const Expr& expr) {
} else if (const auto* op = expr.as<CallNode>()) {
// tail recursion.
if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op ||
op->op == squeeze_op) {
op->op == squeeze_op || op->op == repeat_op) {
return IsAllPositiveConstant(op->args[0]);
} else {
return false;
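repeat is added to the peel-through list because the scalar-constant path added below in pattern_utils.h now builds the scale as expand_dims followed by repeat, and fold_scale_axis only folds a scale that IsAllPositiveConstant can verify. Roughly, the scale expression that now has to be peeled looks like this (an illustrative Relay snippet, not code from this PR):

import tvm
from tvm import relay

# A positive scalar scale wrapped in expand_dims + repeat should still be
# recognized as all-positive by IsAllPositiveConstant.
c = relay.const(2.0, "float32")                  # rank-0 constant
scale = relay.expand_dims(c, axis=0)             # shape (1,)
scale = relay.repeat(scale, repeats=16, axis=0)  # shape (16,), one entry per channel
print(scale)                                     # prints the (unevaluated) Relay expression
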
10 changes: 0 additions & 10 deletions src/relay/ir/dataflow_matcher.cc
@@ -427,16 +427,6 @@ bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr& exp
return false;
}

Expr InferType(const Expr& expr) {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
if (expr.as<FunctionNode>()) {
return mod->Lookup("main");
} else {
return mod->Lookup("main").as<FunctionNode>()->body;
}
}

Expr InferTypeWithModule(const Expr& expr, const IRModule& m) {
IRModule mod(m->functions, m->type_definitions, m->Imports());
GlobalVarSupply global_var_supply = GlobalVarSupply(mod);
29 changes: 29 additions & 0 deletions src/relay/transforms/pattern_utils.h
@@ -102,6 +102,24 @@ namespace relay {
LOG(FATAL) << "unknown data type " << type; \
}

/*!
* \brief Try to run type inference over expr:
*
* Runs InferType over every node in expr.
*
* \param expr The IR expression.
* \return The inferred expression if inference succeeds.
*/
inline Expr InferType(const Expr& expr) {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
if (expr.as<FunctionNode>()) {
return mod->Lookup("main");
} else {
return mod->Lookup("main").as<FunctionNode>()->body;
}
}

/*!
* \brief Try to match lhs and rhs via broadcasting rule, such that:
*
@@ -121,6 +139,17 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const TensorTyp
size_t base = tlhs->shape.size() - trhs->shape.size();
size_t j = 0;

// Handle the case where trhs is a simple scalar (rank-0) constant.
if (trhs->shape.size() == 0 && rhs_value != nullptr && lhs_axes.size() > 0) {
*rhs_value = MakeExpandDims(*rhs_value, 0, lhs_axes.size());
for (size_t i = 0; i < lhs_axes.size(); i++) {
int repeat_value =
tlhs->shape[static_cast<size_t>(lhs_axes[i]->value)].as<IntImmNode>()->value;
*rhs_value = MakeRepeat(*rhs_value, repeat_value, i);
}
return true;
}

ObjectPtr<SqueezeAttrs> squeeze_attrs;
if (rhs_value != nullptr) {
squeeze_attrs = make_object<SqueezeAttrs>();
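The new branch handles a rank-0 (scalar) rhs: the scalar is expanded once per folded axis and repeated to the lhs extent on that axis, so the usual axis-matching logic below still applies. A rough numpy analogue of that expansion (an assumption about the intent, not the C++ itself):

import numpy as np

lhs_shape = (1, 16, 200, 200)   # e.g. NCHW conv output
lhs_axes = [1]                  # fold over the channel axis
scale = np.float32(2.0)         # rank-0 rhs constant

expanded = np.expand_dims(scale, axis=0)   # add a leading dim (the C++ adds one per folded axis)
for i, ax in enumerate(lhs_axes):
    expanded = np.repeat(expanded, lhs_shape[ax], axis=i)
print(expanded.shape)           # (16,) -- broadcastable onto axis 1 of lhs_shape
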
32 changes: 32 additions & 0 deletions src/relay/transforms/simplify_expr.cc
@@ -847,6 +847,37 @@ class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite {
DFPattern c2_;
};

/*! \brief Simplifying x+x to x*2 */
class SimplifyAdd : public DFPatternRewrite {
public:
SimplifyAdd() {
x_ = IsWildcard();
y_ = IsWildcard();
pattern_ = IsOp("add")({x_, y_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
Type pre_type = pre->checked_type_;
auto dtype = pre_type.as<TensorTypeNode>()->dtype;
auto x = node_map[x_][0];
auto y = node_map[y_][0];
auto data_type = Downcast<TensorType>(x->checked_type());

if (x == y) {  // both pattern inputs bound to the same node, i.e. literally add(%1, %1)
Expr value;
value = MakeConstantScalar(dtype, 2);
return InferType(Call(Op::Get("multiply"), {x, value}));
}
return post;
}

private:
/*! \brief Pattern input */
DFPattern x_;
DFPattern y_;
};

/*! \brief Simplifying x/sqrt to x*rsqrt */
class SimplifyRSqrt : public DFPatternRewrite {
public:
@@ -925,6 +956,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
composer.AddRewrite<ConcretizeCastLikeRewrite>();
composer.AddRewrite<SimplifyAdd>();
composer.AddRewrite<SimplifyRSqrt>();
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
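With SimplifyAdd registered in the composer, the rewrite is reachable through the ordinary SimplifyExpr pass. A quick way to see it in isolation (a sketch that mirrors the unit test added below; shapes are arbitrary):

import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))
mod = relay.transform.InferType()(mod)
mod = relay.transform.SimplifyExpr()(mod)
print(mod["main"].body)   # expected: multiply(%x, 2f)
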
51 changes: 51 additions & 0 deletions tests/python/relay/test_pass_fold_scale_axis.py
@@ -20,6 +20,12 @@
from tvm import te
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import create_workload
from tvm.relay.build_module import bind_params_by_name


def initializer(_, param):
param = np.zeros(param.shape)


def _get_positive_scale(size):
@@ -636,6 +642,50 @@ def check(shape, in_channels, channels, blocking):
check((2, 2, 10, 10, 2), 4, 8, (2, 2))


def test_fold_bwd_simple_constant():
def before(data, weight, out_bias, channels):
y = relay.nn.conv2d(
data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
)

y = relay.add(y, out_bias)
c2 = relay.const(2.0)
y = relay.nn.relu(y)
y = relay.multiply(y, c2)
mod, params = create_workload(y, initializer)
mod["main"] = bind_params_by_name(mod["main"], params)
return mod

def expected(data, weight, out_bias, channels):
y0 = relay.nn.conv2d(
data=data, weight=weight, kernel_size=(3, 3), channels=16, padding=(1, 1)
)
y0 = relay.add(y0, out_bias)
y0 = relay.nn.relu(y0)
mod, params = create_workload(y0, initializer)
mod["main"] = bind_params_by_name(mod["main"], params)
return mod

def check(shape, channels):
x = relay.var("data", relay.TensorType(shape, "float32"))
weight = relay.var("weight")
out_bias = relay.var("in_bias", shape=(channels, 1, 1))

y0 = before(x, weight, out_bias, channels)
remove_last_multiply = tvm.transform.Sequential(
[
relay.transform.InferType(),
relay.transform.FoldScaleAxis(),
]
)
with tvm.transform.PassContext(opt_level=3):
y0 = remove_last_multiply(y0)
_expect = expected(x, weight, out_bias, channels)
tvm.ir.assert_structural_equal(y0, _expect)

check((1, 3, 200, 200), 16)


def test_fold_bwd_dual_consumer():
def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking):
args = [x, conv_weight, out_bias]
@@ -1211,6 +1261,7 @@ def check(shape, in_channels, channels, blocking):
test_fold_fwd_relu_fail()
test_fold_fwd_negative_scale()
test_fold_fwd_dense()
test_fold_bwd_simple_constant()
test_fold_bwd_simple()
test_fold_bwd_dual_path()
test_fold_bwd_dual_consumer()
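The new backward-fold test can be run on its own with pytest (assuming a TVM development checkout with the Python bindings on the path and the repository root as the working directory):

import pytest

# Run only the test added in this commit.
pytest.main(["tests/python/relay/test_pass_fold_scale_axis.py::test_fold_bwd_simple_constant", "-q"])
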
15 changes: 15 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
@@ -729,5 +729,20 @@ def expected():
assert tvm.ir.structural_equal(opt, ref)


def test_simplify_add():
x = relay.var("x", shape=(1, 3, 100, 100), dtype="float32")

def before():
return relay.add(x, x)

def expected():
s = relay.const(2.0)
return relay.multiply(x, s)

opt = run_opt_pass(before(), transform.SimplifyExpr())
ref = run_infer_type(expected())
assert tvm.ir.structural_equal(opt, ref)


if __name__ == "__main__":
pytest.main([__file__])
