Block-wise tuning for PyTorch model (#818)
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: Kaihui-intel <[email protected]>
yiliu30 and Kaihui-intel authored May 12, 2023
1 parent 42f0816 commit 9c26ed7
Showing 10 changed files with 310 additions and 22 deletions.
@@ -6,5 +6,5 @@ scikit-learn
Keras-Preprocessing
onnx
onnxruntime
transformers >= 4.16.0
transformers>=4.16.0
torch>=1.9.0
6 changes: 3 additions & 3 deletions neural_compressor/adaptor/onnxrt.py
@@ -1014,13 +1014,13 @@ def query_fw_capability(self, model):
index + block_len - 1 < len(attention_matmul):
ffn_matmul.append([attention_matmul[index + block_len - 2],
attention_matmul[index + block_len - 1]])
block_info = []
block_wise = []
for block in reversed(ffn_matmul):
node_info = []
for node in block:
node_info.append((node.name, node.op_type))
if len(node_info) != 0:
block_info.append(node_info)
block_wise.append(node_info)

for _, node in enumerate(self.pre_optimized_model.nodes()):
# for TRT EP, only insert Q/DQ to inputs of Add nodes followed by ReduceMean
@@ -1081,7 +1081,7 @@ def query_fw_capability(self, model):
op_wise.update(
{(node.name, node.op_type): copy.deepcopy(optype_wise[node.op_type])})

return {'optypewise': optype_wise, 'opwise': op_wise, 'recipes_ops': recipes_ops, 'block_info': block_info}
return {'optypewise': optype_wise, 'opwise': op_wise, 'recipes_ops': recipes_ops, 'block_wise': block_wise}

def _optypewise_filter_for_qdq(self, optype_wise):
"""Filter optypes that don't support per_channel in QDQ format.
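
For context, the ONNX Runtime adaptor now reports the detected FFN blocks under the 'block_wise' key (previously 'block_info') of the capability dict returned by query_fw_capability. A minimal sketch of how a tuning strategy could read that entry, with hypothetical node names and illustrative values:

capability = {
    'optypewise': {},
    'opwise': {},
    'recipes_ops': {},
    'block_wise': [
        # one entry per FFN block: a list of (node_name, op_type) pairs
        [('/h.1/mlp/fc_in/MatMul', 'MatMul'), ('/h.1/mlp/fc_out/MatMul', 'MatMul')],
        [('/h.0/mlp/fc_in/MatMul', 'MatMul'), ('/h.0/mlp/fc_out/MatMul', 'MatMul')],
    ],
}

for block in capability.get('block_wise', []):
    print("block fallback candidates:", [name for name, op_type in block])
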
58 changes: 48 additions & 10 deletions neural_compressor/adaptor/pytorch.py
@@ -1046,18 +1046,18 @@ def _get_quantizable_ops(self, model):
tmp_model = model
tmp_model.eval()
quantizable_ops = []
self.block_wise = []
self._get_quantizable_ops_recursively(tmp_model, '', quantizable_ops)
# capability = self.query_handler.get_quantization_capability()['dynamic'] \
# if self.approach == "post_training_dynamic_quant" else \
# self.query_handler.get_quantization_capability()['quant_aware'] \
# if self.approach == "quant_aware_training" else \
# self.query_handler.get_quantization_capability()['static']

q_capability = {}
q_capability['block_wise'] = None
q_capability['optypewise'] = OrderedDict()
q_capability['opwise'] = OrderedDict()
# add block ops
if self.block_wise:
logger.debug(f"*** Found {len(self.block_wise)} blocks: {self.block_wise}")
q_capability['block_wise'] = self.block_wise[::-1] if self.block_wise else None

quant_datatypes = self.query_handler.get_quant_datatypes()

if self.approach == "quant_aware_training":
capability_pair = [(self.query_handler.get_quantization_capability()['quant_aware'], 'static')]
fp32_config = {'activation': {'dtype': 'fp32'}, 'weight': {'dtype': 'fp32'}}
@@ -2948,7 +2948,15 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
Returns:
None
"""


# group ops by position for transformer-based models
from .torch_utils.pattern_detector import TransformerBasedModelBlockPatternDetector
detector = TransformerBasedModelBlockPatternDetector(model)
detect_result = detector.detect_block()
attention_block = detect_result.get("attention_blocks", None)
ffn_blocks = detect_result.get("ffn_blocks", None)
logger.info(f"Attention Blocks: {len(attention_block)}")
logger.info(f"FFN Blocks: {len(ffn_blocks)}")
if not os.path.exists(self.ipex_config_path):
assert isinstance(model, torch.nn.Module), \
"The model passed in is not the instance of torch.nn.Module"
@@ -3040,7 +3048,8 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
del tmp_model
import gc
gc.collect()

map_op_name_to_fqn = {}

with open(self.ipex_config_path, 'r') as f:
self.cfgs = json.load(f)
if self.version.release < Version("1.12.0").release: # pragma: no cover
@@ -3071,27 +3080,45 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
module_key = name[0][0]
op_cfg_id = name[0][2]
ipex_op_type = self.cfgs[module_key]['q_op_infos'][op_cfg_id]['op_type']
module_fqn = self.cfgs[module_key]['q_op_infos'][op_cfg_id].get('fqn', None)

if ipex_op_type in unify_op_type_mapping_ipex:
quantizable_ops.append((tuple(name),
unify_op_type_mapping_ipex[ipex_op_type]))
map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
else:
re_flag = False
for pattern, unify_op_type in unify_op_type_mapping_ipex['re'].items():
if re.match(pattern, ipex_op_type):
re_flag = True
quantizable_ops.append((tuple(name), unify_op_type))
map_op_name_to_fqn[(tuple(name), unify_op_type)] = module_fqn
break
if not re_flag:
quantizable_ops.append((tuple(name), ipex_op_type))
map_op_name_to_fqn[(tuple(name), ipex_op_type)] = module_fqn
else:
op_type = ""
for op_name in name:
module_key = op_name[0]
op_cfg_id = op_name[2]
op_type += self.cfgs[module_key]['q_op_infos'][op_cfg_id]['op_type']
quantizable_ops.append((tuple(name), op_type))
_module_key = name[0][0]
_op_cfg_id = name[0][2]
module_fqn = self.cfgs[_module_key]['q_op_infos'][_op_cfg_id]['fqn']
map_op_name_to_fqn[(tuple(name), op_type)] = module_fqn
self.op_infos_from_cfgs = op_infos_from_cfgs
self.output_tensor_id_op_name = output_tensor_id_op_name
logger.debug("Map op name to fqn: ")
logger.debug(map_op_name_to_fqn)
logger.info("Attention Blocks : ")
logger.info(attention_block)
logger.info("FFN Blocks : ")
logger.info(ffn_blocks)
self.block_wise = ffn_blocks


os.remove(self.ipex_config_path)

def get_fuse_ops(self, default_cfgs):
@@ -3887,14 +3914,22 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
Returns:
None
"""
from .torch_utils.pattern_detector import TransformerBasedModelBlockPatternDetector
from .torch_utils.util import get_op_type_by_name
detector = TransformerBasedModelBlockPatternDetector(model)
detect_result = detector.detect_block()
attention_block = detect_result.get("attention_blocks", None)
ffn_blocks = detect_result.get("ffn_blocks", None)
logger.info(f"Attention Blocks: {len(attention_block)}")
logger.info(f"FFN Blocks: {len(ffn_blocks)}")
module_dict = dict(model.named_modules())
for op_name, child in model.named_modules():
if self.is_fused_module(child):
for name, _ in child.named_children():
module_prefix = op_name + '.' + name
if module_prefix in module_dict:
module_dict.pop(module_prefix) # remove sub-modules of fused modules

q_ops_set = set()
for op_name, child in module_dict.items():
if type(child) in self.white_list \
and type(child) != torch.nn.Sequential \
@@ -3903,6 +3938,9 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
(op_name, unify_op_type_mapping[str(child.__class__.__name__)]
if str(child.__class__.__name__) in unify_op_type_mapping else str(
child.__class__.__name__)))
q_ops_set.add(op_name)
block_wise = [[(name, get_op_type_by_name(name, quantizable_ops)) for name in block] for block in ffn_blocks]
self.block_wise = block_wise

def _get_module_scale_zeropoint(self, model, tune_cfg, prefix=''):
"""get activation scale and zero_point for converted module.
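
A minimal sketch of the flow the PyTorch adaptor changes above implement: detect the FFN blocks with the new pattern detector, pair each detected module name with the op type already recorded in quantizable_ops, and keep the result as block-wise tuning information. The build_block_wise helper below is only for illustration and is not part of the commit:

from neural_compressor.adaptor.torch_utils.pattern_detector import \
    TransformerBasedModelBlockPatternDetector
from neural_compressor.adaptor.torch_utils.util import get_op_type_by_name

def build_block_wise(model, quantizable_ops):
    """Return [[(module_name, op_type), ...], ...], one inner list per FFN block."""
    detector = TransformerBasedModelBlockPatternDetector(model)
    ffn_blocks = detector.detect_block().get("ffn_blocks", []) or []
    return [[(name, get_op_type_by_name(name, quantizable_ops)) for name in block]
            for block in ffn_blocks]
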
164 changes: 164 additions & 0 deletions neural_compressor/adaptor/torch_utils/pattern_detector.py
@@ -0,0 +1,164 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Block detector for Transformer-based model."""

from ...utils.utility import LazyImport
torch = LazyImport("torch")
from typing import Dict, List, Union

from ...utils import logger
from .util import get_depth, get_dict_at_depth, get_element_under_depth

BLOCK_PATTERNS = [
# [['OP_TYPE1', NUM_OPS], ['OP_TYPE2', NUM_OPS], ...]
[['Linear', 4], ['Linear', 4]], # TODO add model name
[['Linear', 2], ['Linear', 2]], # TODO add model name
[['Conv1D', 2], ['Conv1D', 2]], # GPT-2
[['Linear', 4], ['Linear', 3]], # Llama
[['Linear', 4], ['Linear', 2]], # T5-Encoder, OPT
[['Linear', 4], ['Linear', 1], ['Linear', 1]], # Bert
[['Linear', 4], ['Linear', 4], ['Linear', 2]], # T5-Decoder
]


class TransformerBasedModelBlockPatternDetector:
"""Detect the attention block and FFN block in transformer-based model."""

def __init__(self, model: torch.nn.Module, pattern_lst: List[List[Union[str, int]]] = BLOCK_PATTERNS) -> None:
"""Init the block detector.
Args:
model: the model to be detected.
pattern_lst: block pattern list.
"""
self.model = model
self.pattern_lst = pattern_lst
self.pos_info = None

def detect_block(self) -> Dict[str, List[List[str]]]:
"""
Traverse the model definition and return the attention blocks and ffn blocks.
Returns:
blocks: A dict include the detected attention blocks and ffn blocks.
"""
# Step 1: Traverse model definition and record the op position
if not self.pos_info:
pos_info = {0: {}}
self.traverse_model(self.model, result=pos_info)
self.pos_info = pos_info
# Step 2: Traverse all blocks in different depths and record the blocks that matched the pattern
detect_result = []
for pattern in self.pattern_lst:
_, result = self._search_pattern(pos_info, pattern)
if result:
detect_result.append((result, pattern))
# Step 3: Get the attention blocks and ffn blocks
blocks = {"attention_blocks": None, "ffn_blocks": None}
blocks["attention_blocks"], blocks["ffn_blocks"] = self._group_block(detect_result)
logger.info(f'FFN BLOCKS: {blocks["ffn_blocks"]}')
logger.info(f'Attention BLOCKS: {blocks["attention_blocks"]}')
return blocks

@staticmethod
def traverse_model(model, prefix="", depth=1, result=None, key = 0):
"""Traverse the pytorch model according to its hierarchical structure.
Args:
model: input model to be traversed.
prefix: prefix of module. Defaults to "".
depth: current traverse depth. Defaults to 1.
result: depth and its included ops. Defaults to {0: {}}.
key: current root key. Defaults to 0.
"""
module_lst =list(model.named_children())
if len(module_lst) == 0:
# layer name: 'encoder.layer.7.attention.self.query'
# model repr: Linear(in_features=768, out_features=768, bias=True)
# class name: 'Linear'
result[key] = (prefix, model, model.__class__.__name__)
for i, (name, sub_module) in enumerate(module_lst, 1):
indent = " "*depth
new_name = prefix + '.' + name if prefix != "" else name
model_type = sub_module.__class__.__name__
logger.debug( f"Depth: [{depth}]" + indent + f"[{model_type}]{ new_name}")
sub_key = (depth, i, model_type)
if sub_key not in result[key]:
result[key][sub_key] = dict()
TransformerBasedModelBlockPatternDetector.traverse_model(sub_module, prefix=new_name, \
depth=depth+1, result=result[key], key = sub_key)

@staticmethod
def _search_pattern(pos_info: Dict, pattern: List[List[Union[str, int]]]) -> List[List[str]]:
"""Search all blocks that matched the pattern.
Args:
pos_info: the position information of ops.
pattern: block pattern.
Returns:
The number of matched blocks and the matched blocks.
"""
max_depth = get_depth(pos_info)
matched_cnt = 0
result = []
for depth in range(max_depth, -1, -1):
attention_depth = depth
depth_block_lst = []
get_dict_at_depth(pos_info, attention_depth, depth_block_lst, 0)
target_op_types = set(pair[0] for pair in pattern)
for i, block in enumerate(depth_block_lst):
sub_block_lst = []
get_dict_at_depth(block, 1, sub_block_lst, 0)
block_pattern = []
block_result = []
for sub_block in sub_block_lst:
ops_lst = []
get_element_under_depth(sub_block, ops_lst)
filter_ops = [op for op in ops_lst if op[2] in target_op_types]
if len(filter_ops) > 0:
sub_block_pattern = [filter_ops[0][2], len(filter_ops)]
block_pattern.append(sub_block_pattern)
ops_name = [op[0] for op in filter_ops]
block_result.append(ops_name)
if block_pattern == pattern:
matched_cnt += 1
logger.info(f"[DEPTH] {depth} [BLOCK] {i}, Found block match pattern {pattern}!!")
logger.info(f"[Block keys] {block.keys()}")
logger.info(f"[Block Ops] { [pair[0] for pair in ops_lst if pair[2] in target_op_types]}")
result.append(block_result)
if matched_cnt > 0:
logger.info(f" Found {matched_cnt} blocks")
return matched_cnt, result

@staticmethod
def _group_block(detect_result):
"""Collect attention and ffn blocks from detect result."""
import itertools
ffn_block_lst = []
attention_block_lst = []
for block_lst, pattern in detect_result:
for block in block_lst:
# Group the first block as attention blocks and
# the remaining blocks belong to ffn block.
if block:
attention_block_lst.append(block[0])
ffn_block = list(itertools.chain(*block[1:]))
if ffn_block:
ffn_block_lst.append(ffn_block)
return attention_block_lst, ffn_block_lst
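
A usage sketch for the new detector on a hand-built toy model (all class names below are hypothetical, and neural_compressor plus torch are assumed to be installed). Each block holds four attention Linear layers and two FFN Linear layers, so it matches the [['Linear', 4], ['Linear', 2]] (OPT-like) entry of BLOCK_PATTERNS:

import torch
from neural_compressor.adaptor.torch_utils.pattern_detector import \
    TransformerBasedModelBlockPatternDetector

class ToyAttention(torch.nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.q = torch.nn.Linear(dim, dim)
        self.k = torch.nn.Linear(dim, dim)
        self.v = torch.nn.Linear(dim, dim)
        self.out = torch.nn.Linear(dim, dim)

class ToyMLP(torch.nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.fc_in = torch.nn.Linear(dim, dim)
        self.fc_out = torch.nn.Linear(dim, dim)

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = ToyAttention()
        self.mlp = ToyMLP()

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([ToyBlock(), ToyBlock()])

detector = TransformerBasedModelBlockPatternDetector(ToyModel())
blocks = detector.detect_block()
# Each ToyBlock matches [['Linear', 4], ['Linear', 2]], so:
# blocks["attention_blocks"] -> [['layers.0.attention.q', ..., 'layers.0.attention.out'], ...]
# blocks["ffn_blocks"]       -> [['layers.0.mlp.fc_in', 'layers.0.mlp.fc_out'], ...]
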
31 changes: 30 additions & 1 deletion neural_compressor/adaptor/torch_utils/util.py
@@ -1001,4 +1001,33 @@ def calculate_quant_min_max(unsigned, num_bits):
else:
quant_min, quant_max = -1 * 2.0**(num_bits - 1), 2.0**(num_bits - 1) - 1
return quant_min, quant_max


def get_depth(d) -> int:
"""Query the depth of the dict."""
if isinstance(d, dict):
return 1 + max(get_depth(v) for v in d.values())
return 0

def get_dict_at_depth(d, target_depth, result, depth=0):
"""Get all sub-dicts that are at a specified depth in a nested dict."""
if depth == target_depth:
result.append(d)
return
elif depth < target_depth and isinstance(d, dict):
for k, v in d.items():
get_dict_at_depth(v, target_depth, result, depth=depth+1)

def get_element_under_depth(d, ops_lst):
"""Get all values in a nested dict."""
if isinstance(d, dict):
for k, v in d.items():
get_element_under_depth(v, ops_lst)
else:
ops_lst.append(d)

def get_op_type_by_name(op_name, quantizable_ops):
"""Get op type by op name."""
for pair in quantizable_ops:
if pair[0] == op_name:
return pair[1]
return None
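
A small, self-contained sketch of the new helpers on a hand-built nested dict (the op names are hypothetical), mirroring how the pattern detector consumes them:

from neural_compressor.adaptor.torch_utils.util import (
    get_depth, get_dict_at_depth, get_element_under_depth, get_op_type_by_name)

nested = {0: {('layer', 1): {('attn', 1): ('model.attn.q', None, 'Linear'),
                             ('attn', 2): ('model.attn.out', None, 'Linear')}}}

print(get_depth(nested))                 # 3: three nested dict levels above the leaf tuples

sub_dicts = []
get_dict_at_depth(nested, 2, sub_dicts)  # collect the dicts sitting at depth 2
print(len(sub_dicts))                    # 1: the innermost dict holding the op tuples

ops = []
get_element_under_depth(nested, ops)     # flatten all leaf values, whatever their depth
print([op[0] for op in ops])             # ['model.attn.q', 'model.attn.out']

quantizable_ops = [('model.attn.q', 'Linear'), ('model.attn.out', 'Linear')]
print(get_op_type_by_name('model.attn.q', quantizable_ops))  # 'Linear'
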
3 changes: 1 addition & 2 deletions neural_compressor/strategy/basic.py
@@ -179,7 +179,7 @@ def fallback_by_block(self, fallback_items_lst, best_op_tuning_cfg_stage1, targe
dict: op_tuning_cfg fall-backed by block
"""
from copy import deepcopy
op_block_lst = self.capability.get('block_info', [])
op_block_lst = self.capability.get('block_wise', [])
if op_block_lst:
# Fallback block by block
fallback_items_name_lst = [item.name for item in fallback_items_lst]
@@ -192,7 +192,6 @@ def fallback_by_block(self, fallback_items_lst, best_op_tuning_cfg_stage1, targe
op_block_fallback_lst.append(op_block)

initial_op_tuning_cfg = deepcopy(best_op_tuning_cfg_stage1)

# Fallback by accumulating blocks
if op_block_fallback_lst:
logger.info(f"Start to fallback op to {target_dtype} by blocks")
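
An illustrative, simplified sketch of the block-wise fallback idea in fallback_by_block above (op names are hypothetical and the filtering condition is a simplification of the elided code): a block is kept as a fallback unit only when its ops are among the fallback candidates, and the qualifying blocks are then fallen back one accumulated block at a time:

block_wise = [
    [('h.1.mlp.fc_in', 'Linear'), ('h.1.mlp.fc_out', 'Linear')],
    [('h.0.mlp.fc_in', 'Linear'), ('h.0.mlp.fc_out', 'Linear')],
]
fallback_items_name_lst = ['h.1.mlp.fc_in', 'h.1.mlp.fc_out', 'h.0.mlp.fc_in']

op_block_fallback_lst = []
for op_block in block_wise:
    if all(name in fallback_items_name_lst for name, _ in op_block):
        op_block_fallback_lst.append(op_block)

print(op_block_fallback_lst)  # only h.1's FFN block qualifies; h.0 is missing fc_out
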
2 changes: 0 additions & 2 deletions neural_compressor/strategy/utils/tuning_sampler.py
@@ -432,7 +432,6 @@ def __iter__(self):

class BlockFallbackTuningSampler(TuningSampler):
"""Not displayed in API Docs."""

def __init__(self,
tuning_space: TuningSpace,
tuning_order_lst: List[TuningOrder],
@@ -484,4 +483,3 @@ def __iter__(self):
logger.debug(f"[BlockFallbackTuningSampler] updated_tuning_cfg {op_name_type}: {new_op_config}")
logger.debug(f"[BlockFallbackTuningSampler] fallback {op_name_type} to {self.target_dtype}")
yield new_tune_cfg
