Dnnl style codegen (apache#1)
* Checkpoint, nothing works

* DNNL based codegen almost works

* Work in dnnl style

* Work in dnnl style

* Arg passing works

* Work in dnnl style

* Codegen somewhat works

* Requantization not working

* Codegen works

* Remove headsail_old
vilukissa68 authored Oct 4, 2024
1 parent 20e5bde commit d24af7d
Showing 5 changed files with 905 additions and 182 deletions.
279 changes: 262 additions & 17 deletions python/tvm/relay/op/contrib/headsail.py
@@ -32,10 +32,18 @@
- The other way is to implement the check function themselves, inspecting the attributes
  of the op to decide if it should be offloaded to DNNL.
"""
import logging
import tvm.ir
from tvm import relay
from ...dataflow_pattern import DFPatternCallback, is_constant, is_expr, is_op, rewrite, wildcard
from tvm.relay.expr import Call, GlobalVar, TupleGetItem, const
from tvm.relay import transform
from .register import register_pattern_table

from ..strategy.generic import is_depthwise_conv2d
logger = logging.getLogger("HEADSAIL")

conv2d_counter = True

def _register_external_op_helper(op_name, supported=True):
"""The helper function to indicate that a given operator can be supported
@@ -53,32 +61,269 @@ def _register_external_op_helper(op_name, supported=True):
"""
@tvm.ir.register_op_attr(op_name, "target.headsail")
def _func_wrapper(expr):
args = expr.args
typ = args[0].checked_type
if typ.dtype != "int8":
return False

global conv2d_counter
if conv2d_counter == True:
conv2d_counter = False
logger.info(expr.span)
return supported

return _func_wrapper


#_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.relu")
#_register_external_op_helper("add")

#_register_external_op_helper("qnn.add")
#_register_external_op_helper("qnn.conv2d")
#_register_external_op_helper("qnn.relu")

def make_pattern(with_bias=True):
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    conv = is_op("nn.conv2d")(data, weight)
    if with_bias:
        conv_out = is_op("add")(conv, bias)
    else:
        conv_out = conv
    return is_op("nn.relu")(conv_out)


# Special case to handle tflite models converted to relay with fused activation
def qnn_tflite_conv2d_bias_relu():
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    pattern = is_op("qnn.conv2d")(
        data, weight, is_constant(), is_constant(), is_constant(), is_constant()
    )
    pattern = is_op("nn.bias_add")(pattern, bias)
    pattern = is_op("qnn.requantize")(
        pattern, is_constant(), is_constant(), is_constant(), is_constant()
    )
    pattern = is_op("clip")(pattern)
    return pattern
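# Illustrative sketch of the kind of TFLite-derived Relay sequence the pattern above is meant
# to match (operand values below are placeholders, not taken from this commit):
#   %0 = qnn.conv2d(%data, %weight, 0, 0, 0.02f, 0.01f)
#   %1 = nn.bias_add(%0, %bias)
#   %2 = qnn.requantize(%1, 0.0002f, 0, 0.1f, -128)
#   %3 = clip(%2, a_min=-128f, a_max=127f)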

def make_qnn_conv2d_pattern():
    """Make qnn.conv2d based pattern supported by DNNL

    Returns
    -------
    pattern : CallPattern
        Created qnn.conv2d pattern with its optional post ops.
    """
    data = wildcard()
    weight = is_constant()
    bias = is_constant()
    o_scl = is_constant()
    dst_zp = is_constant()
    act_scl = is_constant()
    sum_scl = is_constant()
    sum_src = wildcard()

    zero_zp = is_expr(const(0, dtype="int32"))

    pat = is_op("qnn.conv2d")(data, weight, zero_zp, zero_zp, is_constant(), is_constant())
    pat = is_op("cast")(pat)
    pat = is_op("add")(pat, bias) | pat  # optional bias
    pat = is_op("multiply")(pat, o_scl)
    pat = is_op("clip")(pat)  # TBD, not only clip
    pat = is_op("multiply")(pat, act_scl) | pat  # optional multiply. Ex: act_scl == 1
    pat = is_op("add")(pat, sum_scl * is_op("cast")(sum_src)) | pat  # optional sum
    pat = is_op("add")(pat, dst_zp) | pat  # optional dst_zp, can be dst_zp == 0
    pat = is_op("cast")(pat)
    return pat

@register_pattern_table("headsail")
def pattern_table():
    # Earlier nn.conv2d based patterns, kept for reference:
    #conv2d_bias_relu_pat = ("headsail.conv2d_bias_relu", make_pattern(with_bias=True))
    #conv2d_relu_pat = ("headsail.conv2d_relu", make_pattern(with_bias=False))
    #patterns = [conv2d_bias_relu_pat, conv2d_relu_pat]
    #return patterns
    tflite_conv2d_bias_relu = ("headsail.tflite_conv2d_bias_relu", qnn_tflite_conv2d_bias_relu())
    #tflite_conv2d_bias_relu = ("headsail.tflite_conv2d_bias_relu", make_qnn_conv2d_pattern())
    #tflite_conv2d_bias = ("headsail.tflite_conv2d_bias", qnn_tflite_conv2d_bias())
    return [tflite_conv2d_bias_relu]
    #return [tflite_conv2d_bias_relu, tflite_conv2d_bias]

class LegalizeQnnOpForHeadsail(DFPatternCallback):
    """Legalize QNN based patterns to match DNNL

    original pattern:
      OP = qnn.conv2d
      %1 = OP<int>(SRC, WGH) - OP<int>(src_zp, WGH)                 // qnn.conv2d
      %2 = %1 + orig_bias                                           // bias
      %3 = (%2 - rq_in_zp) * rq_in_scl / rq_out_scl + rq_out_zp     // qnn.requantize
      %4 = act(%3)                                                  // activation == clip

    transform to DNNL compatible:
      %1 = OP<int>(SRC, WGH)
      %2 = cast(%1, dtype="float")
      %3 = (%2 + bias) * o_scl
      %4 = act(%3) * act_scl
      %5 = %4 + SRC2 * sum_scl
      %6 = %5 + dst_zp
      %7 = cast(%6, dtype=<dst dtype>)

    where:
      o_scl = rq_in_scl / rq_out_scl
      act_scl = sum_lhs_scl / sum_out_scl
      sum_scl = sum_rhs_scl / sum_out_scl
      bias = orig_bias - OP(src_zp, WGH) - rq_in_zp + rq_out_zp * rq_out_scl / rq_in_scl
      dst_zp = sum_out_zp - sum_lhs_zp * sum_lhs_scl / sum_out_scl -
               sum_rhs_zp * sum_rhs_scl / sum_out_scl
    """

    def __init__(self):
        super(LegalizeQnnOpForHeadsail, self).__init__()
        self.src = wildcard()
        self.wgh = wildcard()
        self.bias = wildcard()
        self.sum_src = wildcard()

        self.src_scl = is_constant()
        self.src_zp = is_constant()
        self.wgh_scl = is_constant()
        self.wgh_zp = is_expr(const(0))

        self.rq_in_scl = is_constant()
        self.rq_in_zp = is_constant()
        self.rq_out_scl = is_constant()
        self.rq_out_zp = is_constant()

        self.sum_lhs_scl = is_constant()
        self.sum_lhs_zp = is_constant()
        self.sum_rhs_scl = is_constant()
        self.sum_rhs_zp = is_constant()
        self.sum_out_scl = is_constant()
        self.sum_out_zp = is_constant()

        self.root = (is_op("qnn.conv2d") | is_op("qnn.dense"))(
            self.src, self.wgh, self.src_zp, self.wgh_zp, self.src_scl, self.wgh_scl
        )
        pat = is_op("add")(self.root, self.bias) | self.root  # optional bias
        pat = is_op("qnn.requantize")(
            pat, self.rq_in_scl, self.rq_in_zp, self.rq_out_scl, self.rq_out_zp
        )
        pat = is_op("clip")(pat)
        cast = is_op("cast")(pat)
        pat = is_op("qnn.add")(
            cast,
            self.sum_src,
            self.sum_lhs_scl,
            self.sum_lhs_zp,
            self.sum_rhs_scl,
            self.sum_rhs_zp,
            self.sum_out_scl,
            self.sum_out_zp,
        )
        pat = is_op("clip")(pat)
        self.pattern = pat | cast

    def callback(self, pre, post, node_map):
        root = node_map[self.root][0]
        src = node_map[self.src][0]
        wgh = node_map[self.wgh][0]
        bias = node_map.get(self.bias, default=[relay.const(0, dtype="int32")])[0]
        src_zp = node_map[self.src_zp][0]
        rq_in_scl = node_map[self.rq_in_scl][0]
        rq_in_zp = node_map[self.rq_in_zp][0]
        rq_out_scl = node_map[self.rq_out_scl][0]
        rq_out_zp = node_map[self.rq_out_zp][0]

        final_dtype = node_map[self.pattern][0].checked_type.dtype

        if root.op == relay.op.get("qnn.conv2d"):
            dst_layout = root.attrs.out_layout
            dst_layout = root.attrs.data_layout if dst_layout == "" else dst_layout
            wgh_layout = root.attrs.kernel_layout
        else:
            # qnn.dense has no layout attributes. Assume that it is plain.
            dst_layout = "NC"
            wgh_layout = "OI"

        # TODO(@apeskov): dst_layout may be blocked
        bias_rank = len(dst_layout) - dst_layout.index("C")

        sum_src = node_map[self.sum_src][0] if self.sum_src in node_map else None
        # Default values if qnn.add is not present
        sum_lhs_scl = node_map[self.sum_lhs_scl][0] if sum_src else relay.const(1, dtype="float32")
        sum_lhs_zp = node_map[self.sum_lhs_zp][0] if sum_src else relay.const(0, dtype="int32")
        sum_rhs_scl = node_map[self.sum_rhs_scl][0] if sum_src else relay.const(0, dtype="float32")
        sum_rhs_zp = node_map[self.sum_rhs_zp][0] if sum_src else relay.const(0, dtype="int32")
        sum_out_scl = node_map[self.sum_out_scl][0] if sum_src else relay.const(1, dtype="float32")
        sum_out_zp = node_map[self.sum_out_zp][0] if sum_src else relay.const(0, dtype="int32")

        def cast_fp(op):
            return relay.op.cast(op, dtype="float32")

        # recalculate some factors
        o_scl = rq_in_scl / rq_out_scl
        act_scl = sum_lhs_scl / sum_out_scl
        sum_scl = sum_rhs_scl / sum_out_scl
        dst_zp = (
            cast_fp(sum_out_zp)
            - cast_fp(sum_lhs_zp) * sum_lhs_scl / sum_out_scl
            - cast_fp(sum_rhs_zp) * sum_rhs_scl / sum_out_scl
        )
        bias = self.squeeze_bias(bias, dst_layout)
        bias = (
            cast_fp(bias)
            - cast_fp(self.fake_op(src_zp, wgh, wgh_layout))
            - cast_fp(rq_in_zp)
            + cast_fp(rq_out_zp) * rq_out_scl / rq_in_scl
        )
        bias = self.broadcast_to_rank(bias, bias_rank)

        zero_zp = relay.const(0, dtype="int32")
        one_scl = relay.const(1.0, dtype="float32")

        # construct new graph with proper post op ordering
        gr = tvm.relay.Call(
            root.op,
            [src, wgh, zero_zp, zero_zp, one_scl, one_scl],
            root.attrs,
            root.type_args,
            root.span,
        )
        gr = relay.op.cast(gr, dtype="float32")
        gr = gr + bias
        gr = gr * o_scl
        gr = relay.op.clip(gr, 0, 255) * act_scl
        gr = gr + sum_scl * cast_fp(sum_src) if sum_src else gr
        gr = gr + dst_zp
        gr = relay.op.cast(gr, dtype=final_dtype)
        return gr

    @staticmethod
    def fake_op(zp, wgh, layout):
        """Fake operator implementation for zp broadcast input"""
        # Conv: reduce kernel {OC, IC, KH, KW} -> {OC}; in the grouped case that is still correct
        # Dense: reduce kernel {OC, IC} -> {OC}
        wgh_int = relay.op.cast(wgh, dtype="int32")
        reduced_kernel = relay.op.sum(
            wgh_int, axis=[layout.index("O")], keepdims=False, exclude=True
        )
        return zp * reduced_kernel

    @staticmethod
    def squeeze_bias(bias, layout):
        shape = transform.InferTypeLocal(bias).concrete_shape
        c_position = layout.index("C") - len(layout) + len(shape)
        squeeze_idxs = [i for i in range(len(shape)) if i != c_position]
        return relay.op.squeeze(bias, squeeze_idxs)

    @staticmethod
    def broadcast_to_rank(op, rank):
        """Scalar or 1D tensor are supported"""
        shape = transform.InferTypeLocal(op).concrete_shape
        if len(shape) == 0:
            return op
        if len(shape) == 1:
            return relay.op.expand_dims(op, 1, rank - 1)
        raise ValueError("Unexpected bias rank to broadcast. Only 0 and 1 are supported.")
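    # Illustrative example (added for clarity, not taken from this commit): for an NCHW
    # destination layout, bias_rank = len("NCHW") - "NCHW".index("C") = 3, so broadcast_to_rank
    # turns a 1-D bias of shape (64,) into shape (64, 1, 1), which broadcasts over H and W
    # when it is added to the casted convolution output.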


def legalize_qnn_for_headsail(mod):
    """Transform qnn primitives to DNNL compatible form. Eliminate source zero point and apply
    strict sequence of post ops."""
    print("Legalizing qnn for headsail")
    #mod["main"] = rewrite(LegalizeQnnOpForHeadsail(), mod["main"])

    seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            # transform.SimplifyInference(), # TODO: this pass decomposes nn.layer_norm
            # transform.FoldScaleAxis(), # TODO: fails inside TVM in case of grouped convolutions
            transform.FoldConstant(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    return mod
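
For context, the sketch below shows how the pattern table and the legalization pass above would typically be wired into a BYOC partitioning flow, following the convention of TVM's DNNL integration. It is not part of this commit: the helper name partition_for_headsail and the input module mod are assumptions, and the import path is inferred from the file location python/tvm/relay/op/contrib/headsail.py.

# Hypothetical usage sketch, not included in this commit.
import tvm
from tvm import relay
from tvm.relay.op.contrib import headsail  # registers the "headsail" pattern table on import
from tvm.relay.op.contrib.register import get_pattern_table

def partition_for_headsail(mod):
    """Assumed helper: legalize QNN ops, then carve out regions the codegen can offload."""
    mod = headsail.legalize_qnn_for_headsail(mod)
    seq = tvm.transform.Sequential(
        [
            relay.transform.MergeComposite(get_pattern_table("headsail")),
            relay.transform.AnnotateTarget("headsail"),
            relay.transform.MergeCompilerRegions(),
            relay.transform.PartitionGraph(),
        ]
    )
    with tvm.transform.PassContext(opt_level=3):
        return seq(mod)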
