Skip to content

Commit

Permalink
[Relay][Quantize] Use fixed point mulplications (apache#4160)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and kevinthesun committed Oct 30, 2019
1 parent 08776e4 commit dbc1cc7
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 19 deletions.
4 changes: 4 additions & 0 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class QConfig(NodeBase):
"do_simulation": False,
"round_for_shift": True,
"debug_enabled_ops": None,
"rounding": "UPWARD"
}

# pylint: disable=no-member
Expand Down Expand Up @@ -160,6 +161,9 @@ def qconfig(**kwargs):
is None, which means will try to call all operartors' annotate rewrite
function.
rounding: "UPWARD" or "TONEAREST"
Rounding direction for fixed point multiplications.
Returns
-------
config: QConfig
Expand Down
3 changes: 2 additions & 1 deletion src/relay/pass/quantize/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "do_simulation==" << op->do_simulation << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", ";
p->stream << "rounding==" << op->rounding;
p->stream << ")";
});

Expand Down
2 changes: 2 additions & 0 deletions src/relay/pass/quantize/quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class QConfigNode : public Node {
bool do_simulation = false;
bool round_for_shift = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
std::string rounding = "UPWARD";

void VisitAttrs(AttrVisitor* v) {
v->Visit("nbit_input", &nbit_input);
Expand All @@ -88,6 +89,7 @@ class QConfigNode : public Node {
v->Visit("do_simulation", &do_simulation);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
v->Visit("rounding", &rounding);
}

static constexpr const char* _type_key = "relay.quantize.QConfig";
Expand Down
23 changes: 13 additions & 10 deletions src/relay/pass/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/relay/attrs/annotation.h>
#include "./quantize.h"
#include "../pattern_util.h"
#include "../../qnn/util.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -97,7 +98,9 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {


/* calculate `data * s1 / s2`, use shift if possible */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
const Array<IndexExpr> &data_shape) {
const QConfig& cfg = QConfig::Current();
// here we assume the dtype of data is dtype activation
if (s1 == s2) return data;

Expand All @@ -110,9 +113,8 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
data = Cast(data, Float(32));
data = Multiply(data, MakeConstantScalar(Float(32), factor));
return Cast(Round(data), dtype);
data = qnn::FixedPointMultiply(Cast(data, Int(64)), factor, data_shape, cfg->rounding);
return Cast(data, dtype);
}
}

Expand Down Expand Up @@ -164,11 +166,12 @@ Expr QuantizeRealize(const Call& ref_call,
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else {
// float computation
data = Cast(data, Float(32));
Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
data = Cast(data, Int(64));
data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
ref_call->type_as<TensorTypeNode>()->shape,
cfg->rounding);
data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
}
}

Expand Down Expand Up @@ -355,7 +358,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
Expr dom_scale = MakeConstantScalar(Float(32), s);
for (size_t i = 0; i < ret.size(); ++i) {
float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
}

*dtype_ptr = dtype;
Expand Down
6 changes: 2 additions & 4 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ TVM_REGISTER_NODE_TYPE(RequantizeAttrs);

// Lowering of qnn.requantize op



/*
* \brief Lower requantize to a sequence of ops.
* \param input_tensor The input tensor to requantize op.
Expand Down Expand Up @@ -73,8 +71,8 @@ Expr RequantizeLower(const Expr& input_tensor, const RequantizeAttrs* param,
// 2) If the input and output scales are same, we can skip the fixed point multiplication.
auto scaled_int64_t = tensor;
if (param->input_scale != param->output_scale) {
scaled_int64_t = FixedPointMuliply(scaled_int64_t, double_multiplier, input_shape,
param->rounding);
scaled_int64_t =
FixedPointMultiply(scaled_int64_t, double_multiplier, input_shape, param->rounding);
}

// 3) Add the output zero point.
Expand Down
4 changes: 3 additions & 1 deletion src/relay/qnn/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
return std::make_pair(significand, exponent);
}

Expr FixedPointMuliply(Expr tensor, double multiplier,
Expr FixedPointMultiply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape, const std::string& rounding) {
// Choose high precision datatype to be int64. This is for avoiding overflow
// in multiplication of two int32 values.
Expand Down Expand Up @@ -121,6 +121,8 @@ Expr FixedPointMuliply(Expr tensor, double multiplier,
auto zero_t = Zeros(input_shape, hp_dtype);
round_scalar =
Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
} else {
LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
}
// Add the rounding scalar.
tensor = Add(tensor, round_scalar);
Expand Down
6 changes: 3 additions & 3 deletions src/relay/qnn/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ static inline int64_t get_const_int(const tvm::Expr& x) {
* 2) Round the result.
* 3) Right shift the result
*/
Expr FixedPointMuliply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape,
const std::string& rounding);
Expr FixedPointMultiply(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape,
const std::string& rounding);

} // namespace qnn
} // namespace relay
Expand Down

0 comments on commit dbc1cc7

Please sign in to comment.