From 3b2925263845ea6e38df5b50912fbfd4bd6b85ee Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Wed, 22 Nov 2023 23:09:31 +0800 Subject: [PATCH] Support Quantization of Big Saved Model for TF Backend (#1396) Signed-off-by: zehao-intel --- neural_compressor/adaptor/tensorflow.py | 98 +++++- .../adaptor/tf_utils/graph_converter.py | 34 +++ .../graph_rewriter/qdq/insert_qdq_pattern.py | 20 ++ .../adaptor/tf_utils/graph_util.py | 27 +- .../tf_utils/smooth_quant_calibration.py | 281 +++++++++++++++++- .../adaptor/tf_utils/smooth_quant_scaler.py | 113 +++++++ neural_compressor/adaptor/tf_utils/util.py | 168 ++++++++++- neural_compressor/model/model.py | 10 +- neural_compressor/model/tensorflow_model.py | 225 +++++++++++++- test/tfnewapi/test_big_saved_model.py | 143 +++++++++ 10 files changed, 1095 insertions(+), 24 deletions(-) create mode 100644 test/tfnewapi/test_big_saved_model.py diff --git a/neural_compressor/adaptor/tensorflow.py b/neural_compressor/adaptor/tensorflow.py index 2f2d2653228..5ca1950bf39 100644 --- a/neural_compressor/adaptor/tensorflow.py +++ b/neural_compressor/adaptor/tensorflow.py @@ -648,6 +648,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=data_loader, + calib_func=q_func, qdq_enabled=self.qdq_enabled, new_api=self.new_api, performance_only=self.performance_only, @@ -670,6 +671,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=data_loader, + calib_func=q_func, qdq_enabled=self.qdq_enabled, new_api=self.new_api, performance_only=self.performance_only, @@ -693,6 +695,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=data_loader, + calib_func=q_func, qdq_enabled=self.qdq_enabled, new_api=self.new_api, performance_only=self.performance_only, @@ -761,15 +764,15 @@ def _dump_model_op_stats(self, model_graphdef): if i.op in fp32_op_list: if "T" not in i.attr and i.op != "Cast": continue - if i.attr["T"].type == dtypes.bfloat16: - res[i.op]["BF16"] += 1 - elif i.attr["T"].type in (dtypes.quint8, dtypes.qint8): - res[i.op]["INT8"] += 1 - elif i.op == "Cast": + if i.op == "Cast": if i.attr["DstT"].type == dtypes.bfloat16: res[i.op]["BF16"] += 1 elif i.attr["DstT"].type == dtypes.float32: res[i.op]["FP32"] += 1 + elif i.attr["T"].type == dtypes.bfloat16: + res[i.op]["BF16"] += 1 + elif i.attr["T"].type in (dtypes.quint8, dtypes.qint8): + res[i.op]["INT8"] += 1 else: res[i.op]["FP32"] += 1 @@ -1815,7 +1818,6 @@ def smooth_quant( model, dataloader, calib_iter=1, - tune_cfg=None, alpha=0.5, folding=False, percentile=99.999, @@ -1832,7 +1834,6 @@ def smooth_quant( model: original model dataloader: the calibration dataloader calib_iter: how many steps of iterations on the dataloader to move forward - tune_cfg: quantization config alpha: smooth alpha in SmoothQuant, 1.0 will fallback to SPIQ folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant percentile: percentile of calibration to remove outliers @@ -1852,6 +1853,11 @@ def smooth_quant( if self.smooth_quant_model is not None: return self.smooth_quant_model + if model.model_type == "llm_saved_model": + return self.smooth_quant_LLM( + model, dataloader, calib_iter, alpha, folding, percentile, op_types, scales_per_op + ) + # Do a pre-optimization before smooth quant from .tf_utils.graph_rewriter.generic.pre_optimize import PreOptimization @@ -1860,6 +1866,7 @@ def 
smooth_quant( model.graph_def = self.pre_optimized_model.graph_def # Get the nodes list which can't be quantized from tune_cfg + tune_cfg = None black_nodes = [] if tune_cfg is not None: self._tuning_cfg_to_fw(tune_cfg) @@ -1887,6 +1894,81 @@ def smooth_quant( self.smooth_quant_model = model return self.smooth_quant_model + def smooth_quant_LLM( + self, + model, + dataloader, + calib_iter=1, + alpha=0.5, + folding=False, + percentile=99.999, + op_types=["MatMul", "Conv2D"], + scales_per_op=True, + ): + """Convert the model by smooth quant. + + Args: + model: original model of TensorflowLLMModel object. + calib_iter: how many steps of iterations on the dataloader to move forward. + tune_cfg: quantization config. + alpha: smooth alpha in SmoothQuant, 1.0 will fallback to SPIQ. + folding: whether insert mul(False) or just allow foldable layers(True) for SmoothQuant. + percentile: percentile of calibration to remove outliers. + op_types: The op types whose input tensor will be dumped. + scales_per_op: True, each op will have an individual scale, mainly for accuracy. + False, ops with the same input will share a scale, mainly for performance. + + Returns: + model: A smoothed Tensorflow model. + """ + # Do a pre-optimization before smooth quant + from .tf_utils.graph_rewriter.generic.pre_optimize import PreOptimization + + self.pre_optimizer_handle = PreOptimization(model, self.new_api, self.device) + self.pre_optimized_model = self.pre_optimizer_handle.get_optimized_model(self.itex_mode) + model.graph_def = self.pre_optimized_model.graph_def + + # Get the nodes list which can't be quantized from tune_cfg + tune_cfg = None + black_nodes = [] + if tune_cfg is not None: + self._tuning_cfg_to_fw(tune_cfg) + black_nodes = [node for node in self.quantize_config if self.quantize_config[node] == "fp32"] + + # only support per-tensor MatMul now + op_types = ["MatMul"] + llm_temp_dir = self.work_dir + "/temp_saved_model" + # Run calibration to get max values per channel + from .tf_utils.smooth_quant_calibration import SmoothQuantCalibrationLLM + + calibration = SmoothQuantCalibrationLLM( + model._model, + dataloader, + calib_iter, + op_types, + percentile, + black_nodes, + llm_temp_dir, + model.weight_name_mapping, + ) + max_vals_per_channel, sq_target_node_names, sq_weight_tensor_dict, sq_graph_def = calibration( + model.input_node_names, model.output_node_names + ) + + # Calculate the smooth quant scaler and insert Mul op into the graph + from .tf_utils.smooth_quant_scaler import SmoothQuantScalerLLM + + scaler = SmoothQuantScalerLLM(sq_graph_def, alpha, scales_per_op, op_types) + sq_graph_def, sq_weight_scale_dict, mul_list = scaler.transform( + max_vals_per_channel, sq_weight_tensor_dict, sq_target_node_names + ) + model.graph_def = sq_graph_def + model.model_path = llm_temp_dir + model.sq_weight_scale_dict = sq_weight_scale_dict + self.smooth_quant_mul_ops.extend(mul_list) + self.smooth_quant_model = model + return self.smooth_quant_model + @adaptor_registry class Tensorflow_ITEXAdaptor(TensorFlowAdaptor): @@ -1945,6 +2027,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=data_loader, + calib_func=q_func, itex_mode=self.itex_mode, qdq_enabled=self.qdq_enabled, new_api=self.new_api, @@ -1992,6 +2075,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None): fp32_ops=self.fp32_ops, bf16_ops=self.bf16_ops, data_loader=data_loader, + calib_func=q_func, itex_mode=self.itex_mode, qdq_enabled=self.qdq_enabled, 
new_api=self.new_api, diff --git a/neural_compressor/adaptor/tf_utils/graph_converter.py b/neural_compressor/adaptor/tf_utils/graph_converter.py index c237293818d..791ae8e071b 100644 --- a/neural_compressor/adaptor/tf_utils/graph_converter.py +++ b/neural_compressor/adaptor/tf_utils/graph_converter.py @@ -102,6 +102,7 @@ def __init__( fp32_ops=[], bf16_ops=[], data_loader=None, + calib_func=None, fake_quant=False, itex_mode=False, qdq_enabled=False, @@ -116,6 +117,7 @@ def __init__( :param fp32_ops: fall back to fp32 dtype op list :param bf16_ops: fall back to bf16 dtype op list :param data_loader: for calibration phase used dataloader + :param calib_func: for calibration phase used function :param fake_quant: for quantization-aware training model conversion to default model """ self.model = model @@ -139,6 +141,7 @@ def __init__( self._calibration_data = [] self._fp32_print_data = [] self.data_loader = data_loader + self.calib_func = calib_func self._check_tf_version() self._check_args() @@ -157,6 +160,7 @@ def __init__( self._gen_tmp_filenames() self._kl_op_dict = {} self._kl_keys = [] + self._llm_weight_minmax = {} self._print_node_mapping = {} self._enable_kl_op_names = [k for k in self.op_wise_config if self.op_wise_config[k][1] == "kl"] self.scale_info = {} @@ -193,6 +197,14 @@ def _inference(self, model): Args: model(TensorflowBaseModel): input TensorflowBaseModel """ + if self.calib_func: + self.calib_func(model.model) + return + + if model.model_type == "llm_saved_model": + self._inference_llm(model) + return + # ITEX optimization has broken INC calibration process. # INC needs turn off ITEX optimization pass in calibration stage. # TODO ITEX will provide API to replace setting environment variable. @@ -281,6 +293,24 @@ def check_shape(tensor, data): break os.environ["ITEX_REMAPPER"] = "1" + def _inference_llm(self, model): + input_tensor_names = model.input_tensor_names + auto_trackable = model.model + infer = auto_trackable.signatures["serving_default"] + for idx, (inputs, _) in enumerate(self.data_loader): + feed_dict = {} + if len(input_tensor_names) == 1: + feed_dict[input_tensor_names[0]] = inputs + else: + assert len(input_tensor_names) == len(inputs), "inputs len must equal with input_tensor" + for i, input_tensor_name in enumerate(input_tensor_names): + feed_dict[input_tensor_name] = inputs[i] + + _ = infer(**feed_dict) + + if idx >= self.calib_iteration: + break + def _check_tf_version(self): """Check if the installed tensorflow version is supported.""" is_supported_version = False @@ -849,6 +879,9 @@ def _insert_qdq_pairs(self): self._inference(self._sampling_model) self._calibration_data = Helper.gen_valid_sampling_log(tmp_dump_file) + if hasattr(self._sampling_model, "_weight_tensor_minmax_dict"): + self._llm_weight_minmax = self._sampling_model.weight_tensor_minmax_dict + del sampling_graph_def del output_tensor_names del self._sampling_model @@ -868,6 +901,7 @@ def _insert_qdq_pairs(self): self.device, self.performance_only, self.itex_mode, + self._llm_weight_minmax, ).do_transformation() def _convert_qdq(self): diff --git a/neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py b/neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py index b56563f9b13..a1d5e4dcd9b 100644 --- a/neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py +++ b/neural_compressor/adaptor/tf_utils/graph_rewriter/qdq/insert_qdq_pattern.py @@ -46,6 +46,7 @@ def __init__( device, performance_only, itex_mode, + llm_weight_minmax, ): 
"""Initialization.""" super().__init__(model) @@ -58,6 +59,7 @@ def __init__( self.device = device self.performance_only = performance_only self.itex_mode = itex_mode + self.llm_weight_minmax = llm_weight_minmax self.node_details = namedtuple("node_details", ["node", "output"]) self.node_name_mapping = {} self.check_op_list = { @@ -548,6 +550,24 @@ def _insert_qdq_pattern_for_weight_node( # qint8_tensor = np.clip(qint8_tensor, -127, 127).astype(np.int8) min_value = -range_value max_value = range_value + elif weight_node.op == "ReadVariableOp": + min_value = self.llm_weight_minmax[weight_node.name][0] + max_value = self.llm_weight_minmax[weight_node.name][1] + min_value *= range_coefficent + max_value *= range_coefficent + min_value = min(min_value, 0.0) + if min_value == max_value: + if abs(min_value) < 0.000001: + max_value = min_value + 1.0 + elif min_value > 0: + max_value = 2 * min_value + else: + max_value = min_value / 2.0 + range_value = np.max(np.abs([min_value, max_value])) + # qint8_tensor = (np.around(float_tensor * 127.0 / range_value)).astype(np.int8) + # qint8_tensor = np.clip(qint8_tensor, -127, 127).astype(np.int8) + min_value = -range_value + max_value = range_value elif host_op_type == "DepthwiseConv2dNative": float_tensor = tensor_util.MakeNdarray(weight_node.attr["value"].tensor) # get the max values based on dim 0 and 1 for depthwise conv diff --git a/neural_compressor/adaptor/tf_utils/graph_util.py b/neural_compressor/adaptor/tf_utils/graph_util.py index 8da165ab386..827d5521b9d 100644 --- a/neural_compressor/adaptor/tf_utils/graph_util.py +++ b/neural_compressor/adaptor/tf_utils/graph_util.py @@ -1044,8 +1044,33 @@ def gen_per_iter(data): res.append(mixed_str) return res + def separate(line): + """This function is to separate the strings. + + Example: + ';slice__print__;__max:[1];slice__print__;__min:[-1]' --> + [';slice__print__;__max:[1]', ';slice__print__;__min:[-1]'] + """ + separated_lines = [] + for subline in line.split("];"): + if not subline.startswith(";"): + subline = ";" + subline + if not subline.endswith("]"): + subline += "]" + separated_lines.append(subline) + return separated_lines + with open(log_path) as f: - valid_data = [i.strip() for i in f.readlines() if i.startswith(";")] + valid_data = [] + for i in f.readlines(): + if not i.startswith(";"): + continue + line = i.strip() + if line.find("];") != 0: + separated_lines = separate(line) + valid_data += separated_lines + else: + valid_data.append(line) first_line = valid_data[0].rsplit(":")[0] diff --git a/neural_compressor/adaptor/tf_utils/smooth_quant_calibration.py b/neural_compressor/adaptor/tf_utils/smooth_quant_calibration.py index a29863ece5a..4ce2149cc3b 100644 --- a/neural_compressor/adaptor/tf_utils/smooth_quant_calibration.py +++ b/neural_compressor/adaptor/tf_utils/smooth_quant_calibration.py @@ -16,16 +16,26 @@ # limitations under the License. 
"""Tensorflow model calibration process for Smooth Quantization.""" +import copy import logging import os +import tempfile +import time from collections import OrderedDict, UserDict import numpy as np -from tensorflow.core.framework import graph_pb2 -from tensorflow.python.framework import tensor_util +import tensorflow as tf +from tensorflow.core.framework import attr_value_pb2, graph_pb2 +from tensorflow.python.framework import dtypes, tensor_util +from tensorflow.python.saved_model import load, tag_constants +from neural_compressor import Model +from neural_compressor.utils.utility import CaptureOutputToFile + +from .graph_util import GraphAnalyzer +from .graph_util import GraphRewriterHelper as Helper from .quantize_graph_common import QuantizeGraphHelper -from .util import iterator_sess_run +from .util import iterator_sess_run, parse_saved_model, reconstruct_saved_model logger = logging.getLogger("neural_compressor") debug = bool(logger.level == logging.DEBUG) @@ -213,7 +223,7 @@ def __call__(self): Returns: max_vals_per_channel (dict): A dictionary containing the maximum values per channel. - shape_infos (dict): A dictionary containing the shape information. + sq_weight_node_names (dict): A dictionary mapping from weight names to target node names. """ self._generate_calibration_data() max_vals_per_channel = {} @@ -223,3 +233,266 @@ def __call__(self): ) max_vals_per_channel[key] = max_val_per_channel return max_vals_per_channel, self._sq_weight_node_names + + +class SmoothQuantCalibrationLLM(SmoothQuantCalibration): + """A class for performing smooth quantization calibration on a Tensorflow LLM model. + + Args: + model (str): A path to the original Tensorflow model. + iterations (int): The number of iterations to run the calibration process. + op_types (List[str]): The types of operations to be quantized. + percentile (float): The percentile of calibration to remove outliers. + black_nodes (List[str]): A list of node names to be ignored during calibration. + eval_func (function): The function to inference the model. + temp_path (str): The temporary path to store median model. 
+ weight_name_mapping (): A function that convert weight tensor name in autotrackable to node name in graph_def + """ + + def __init__( + self, model_path, dataloader, iterations, op_types, percentile, black_nodes, temp_path, weight_name_mapping + ): + """Initializes a SmoothQuantCalibrationLLM object.""" + self.func = None + self.graph_def = None + self.frozen_func = None + self._saved_model = None + self.model = model_path + self.dataloader = dataloader + self.iterations = iterations + self.op_types = op_types + self.percentile = percentile + self.black_nodes = black_nodes + self.temp_path = temp_path + self.weight_name_mapping = weight_name_mapping + self.print_node_list = [] + self._sq_input_node_names = [] + self._sq_target_node_names = {} + self._sq_output_tensor_dict = {} + self._sq_weight_tensor_dict = {} + + def _parse_calibration_logs(self, tmp_dump_file): + """Parse calibration logs for llm saved_model.""" + valid_data = [] + with open(tmp_dump_file) as file: + for i in file.readlines(): + if i.startswith(";"): + valid_data.append(i.strip()) + + for activation in valid_data: + activation = activation.split(" ") + data = [] + activation_name = "" + per_channel = [] + for idx, s in enumerate(activation): + if idx == 0: + per_channel.append(float(s.rsplit(":")[-1].strip("["))) + activation_name = s.rsplit(":")[0][1:-9] + elif s.find("][") != -1: + pairs = [float(i) for i in s.split("][")] + per_channel.append(pairs[0]) + data.append(per_channel) + per_channel = [pairs[1]] + elif s.find("]]") != -1: + per_channel.append(float(s.strip("]"))) + data.append(per_channel) + else: + per_channel.append(float(s)) + + if activation_name not in self._sq_output_tensor_dict: + self._sq_output_tensor_dict[activation_name] = [np.array(data)] + else: + self._sq_output_tensor_dict[activation_name].append(np.array(data)) + + def _insert_print_for_activation(self, graph_def): + """Insert print node in the graph to do the calibration for llm saved_model.""" + cur_graph = GraphAnalyzer() + cur_graph.graph = graph_def + + graph_info = cur_graph.parse_graph() + for cur_list in self.print_node_list: + pre_node_name = cur_list[0] + post_node_name = cur_list[-1] + insert_node_pairs = [] + top_node = graph_info[pre_node_name].node + if top_node.op == "ConcatV2": + for i in range(top_node.attr["N"].i): + insert_node_pairs.append([top_node.input[i], post_node_name]) + elif top_node.op in ("BatchMatMul", "BatchMatMulV2"): + insert_node_pairs.append([top_node.input[0], post_node_name]) + if graph_info[top_node.input[1]].node.op != "Const": + insert_node_pairs.append([top_node.input[1], post_node_name]) + elif top_node.op in ("Conv2DBackpropInput", "Conv3DBackpropInputV2"): + insert_node_pairs.append([top_node.input[2], post_node_name]) + else: + refresh_pre_node_name = graph_info[pre_node_name].node.input[0] + # Check the Conv2D could be fused with previous Pad or not. + # If so, we need to update the pre-node name correspondingly. 
+ refresh_pre_node = graph_info[Helper.node_name_from_input(refresh_pre_node_name)].node + if refresh_pre_node.op == "Pad" and top_node.op in ("Conv2D", "Conv3D"): + insert_node_pairs.append([refresh_pre_node_name, post_node_name]) + refresh_pre_node_name = refresh_pre_node.input[0] + + insert_node_pairs.append([refresh_pre_node_name, post_node_name]) + + output_names = [] + for node_pair_names in insert_node_pairs: + for index, each_node_name in enumerate(node_pair_names): + name_with_sig = each_node_name + node_name_prefix = name_with_sig.replace(":", "__port__").replace("^", "__hat__") + print_node = Helper.create_node( + "Print", + node_name_prefix + "_print__{}".format(index), + [each_node_name + ":0", each_node_name + ":0"], + ) + + if index == 0: + msg = ";{}__print__:".format(each_node_name) + # workaround for swish_f32, attribute T is not in the op definition + if "swish_f32" in graph_info[pre_node_name].node.name: + src_dt = attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum) + else: + src_dt = graph_info[pre_node_name].node.attr["T"] + else: + break + + print_node.attr["T"].CopyFrom(src_dt) + + print_node.attr["message"].s = msg.encode() + print_node.attr["first_n"].i = -1 + print_node.attr["summarize"].i = 102400000 + + attr_u = [dtypes.as_dtype(src_dt.type).as_datatype_enum] + print_node.attr["U"].list.CopyFrom(attr_value_pb2.AttrValue.ListValue(type=attr_u)) + post_node_names = graph_info[Helper.node_name_from_input(each_node_name)].outputs + if post_node_names: + for post_node_name in post_node_names: + post_node = graph_info[post_node_name].node + if each_node_name not in post_node.input: + continue + if ( + post_node.op == "FusedBatchNormV3" + and "_print_identity" + not in graph_info[Helper.node_name_from_input(post_node.name)].node.input[0] + ): + identity_node = Helper.create_node( + "Identity", + post_node.name + "_print_identity", + [graph_info[Helper.node_name_from_input(post_node.name)].node.input[0]], + ) + identity_node.attr["T"].CopyFrom(src_dt) + cur_graph.add_node( + identity_node, + graph_info[Helper.node_name_from_input(post_node.name)].node.input[0], + [post_node.name], + ) + identity_node.input.append("^" + print_node.name) + else: + post_node.input.append("^" + print_node.name) + + cur_graph.add_node(print_node, each_node_name, []) + else: + identity_node1 = Helper.create_node( + "Identity", print_node.name + "_identity", [print_node.name] + ) + identity_node1.attr["T"].CopyFrom(src_dt) + cur_graph.add_node(print_node, each_node_name, [identity_node1.name]) + cur_graph.add_node(identity_node1, print_node.name, []) + output_names.append(identity_node1.name) + + return cur_graph.dump_graph() + + def evaluate(self, model): + """Evaluate function that inference the model to apply calibration. + + Args: + model (tf.python.training.tracking.tracking.AutoTrackable): The model to be evaluated. + The object is usually gotten by using tf.saved_model.load(model_dir) API. + + Returns: + accuracy (float): The accuracy result. 
+ """ + input_tensor_names = model.input_tensor_names + auto_trackable = model.model + infer = auto_trackable.signatures["serving_default"] + for idx, (inputs, _) in enumerate(self.dataloader): + feed_dict = {} + if len(input_tensor_names) == 1: + feed_dict[input_tensor_names[0]] = inputs + else: + assert len(input_tensor_names) == len(inputs), "inputs len must equal with input_tensor" + for i, input_tensor_name in enumerate(input_tensor_names): + feed_dict[input_tensor_name] = inputs[i] + + _ = infer(**feed_dict) + + if idx >= self.iterations: + break + + def _inference(self, sampling_graph_def): + logger.info("Start sampling on calibration dataset for Smooth Quantization.") + # reconstruct graph_def that inserted print node to saved_model + reconstruct_saved_model(sampling_graph_def, self.func, self.frozen_func, self._saved_model, self.temp_path) + model = Model(self.temp_path, modelType="llm_saved_model") + self.evaluate(model) + + def _inference_for_calibration(self, model): + """Run the calibration on the input graph.""" + sampling_graph_def = self._insert_print_for_activation(model) + tmp_dump_file = tempfile.mkstemp(suffix=".log")[1] + with CaptureOutputToFile(tmp_dump_file): + self._inference(sampling_graph_def) + self._parse_calibration_logs(tmp_dump_file) + del sampling_graph_def + + def _get_weight_tensors(self): + model = load.load(self.model, [tag_constants.SERVING]) + for weight_tensor in model.variables: + parsed_name = self.weight_name_mapping(weight_tensor.name) + if parsed_name in self._sq_target_node_names: + self._sq_weight_tensor_dict[parsed_name] = weight_tensor.numpy() + + assert len(self._sq_weight_tensor_dict) == len( + self._sq_target_node_names + ), "Failed to get weights for some nodes, please check variables" + + def _generate_calibration_data(self, input_node_names, output_node_names): + """Generate the calibration data.""" + sorted_graph = QuantizeGraphHelper().get_sorted_graph( + self.graph_def, + input_node_names, + output_node_names, + ) + + for node in sorted_graph.node: + if node.op not in self.op_types or node.name in self.black_nodes: + continue + # Fix retval already been set issue + if "while" in node.input[0]: # pragma: no cover + continue + self._sq_input_node_names.append(node.input[0]) + self.print_node_list.append([node.name]) + self._sq_target_node_names[node.input[1]] = node.name + self._get_weight_tensors() + sampling_graph_def = copy.deepcopy(self.graph_def) + self._inference_for_calibration(sampling_graph_def) + + def __call__(self, input_node_names, output_node_names): + """Generates calibration data and calculate the maximum values per channel. + + Args: + input_node_names: (list): A list of names for input nodes. + output_node_names: (list): A list of names for output nodes. + + Returns: + max_vals_per_channel (dict): A dictionary containing the maximum values per channel. + sq_target_node_names (dict): A dictionary mapping from weight names to target node names. + sq_weight_tensor_dict (dict): A dictionary containing tensor of weights. 
+ """ + self.graph_def, self._saved_model, self.func, self.frozen_func, _, _ = parse_saved_model(self.model) + self._generate_calibration_data(input_node_names, output_node_names) + max_vals_per_channel = {} + for activation_name, output_tensor in self._sq_output_tensor_dict.items(): + max_val_per_channel = self._get_maxval_per_channel(output_tensor, percentile=self.percentile) + max_vals_per_channel[activation_name] = max_val_per_channel + return max_vals_per_channel, self._sq_target_node_names, self._sq_weight_tensor_dict, self.graph_def diff --git a/neural_compressor/adaptor/tf_utils/smooth_quant_scaler.py b/neural_compressor/adaptor/tf_utils/smooth_quant_scaler.py index be5152dd7e7..45e6e90e2d5 100644 --- a/neural_compressor/adaptor/tf_utils/smooth_quant_scaler.py +++ b/neural_compressor/adaptor/tf_utils/smooth_quant_scaler.py @@ -153,3 +153,116 @@ def transform(self, max_vals_per_channel, sq_weight_tensors, sq_weights_nodes, s sq_graph_def.library.CopyFrom(self.model.graph_def.library) self.model.graph_def = sq_graph_def return self.model, self.mul_list + + +class SmoothQuantScalerLLM(SmoothQuantScaler): + """A class for scaling model weights for TF LLM models using Smooth Quantization method. + + Args: + graph_def: graph_def of the model to be scaled + alpha: float, the scaling factor + scales_per_op: bool, each op will have an individual scale or + ops with the same input will share a scale + op_types: + """ + + def __init__(self, graph_def, alpha, scales_per_op, op_types): + """Initialization.""" + self.graph_def = graph_def + self.alpha = alpha + self.scales_per_op = scales_per_op + self.op_types = op_types + + self.graph_info = None + self.mul_list = [] + self.sq_weight_scale_dict = {} + + def _parse_weight_dict(self, max_vals_per_channel, sq_weight_tensor_dict): + """Parse weight related dictionaries to two required dictionaries. + + Args: + max_vals_per_channel (dict): A dictionary containing the maximum values per channel. + sq_weight_tensor_dict (dict): A dictionary containing tensor of weights. + + Returns: + sq_weight_tensors: A dictionary whose structure is like {input_node_name: weight_tensor}}. + sq_weights_node_names: A dictionary whose structure is like {input_node_name: weight_node_name}}. + """ + sq_weight_tensors = {} + sq_weight_node_names = {} + for input_node_name in max_vals_per_channel: + curr_weight_tensors = [] + curr_weights_node_names = [] + next_node_names = self.graph_info[input_node_name].outputs + for node_name in next_node_names: + curr_node = self.graph_info[node_name].node + if curr_node.op not in self.op_types: + continue + if len(curr_node.input) >= 2: + weight_name = curr_node.input[1] + weight_tensor = sq_weight_tensor_dict[weight_name] + curr_weight_tensors.append(weight_tensor) + curr_weights_node_names.append(weight_name) + sq_weight_tensors[input_node_name] = curr_weight_tensors + sq_weight_node_names[input_node_name] = curr_weights_node_names + return sq_weight_tensors, sq_weight_node_names + + def transform(self, max_vals_per_channel, sq_weight_tensor_dict, sq_target_node_names): + """Apply scaling to weights and activations based on the maximum values per channel. + + Args: + max_vals_per_channel (dict): A dictionary containing the maximum values per channel for each input node. + sq_weight_tensor_dict (dict): A dictionary whose structure is like {input_node_name: weight_tensor}. + sq_target_node_names (dict): A dictionary whose structure is like {weight_node_name: target_node_name}. 
+ """ + self.g_analyzer = GraphAnalyzer() + self.g_analyzer.graph = self.graph_def + self.graph_info = self.g_analyzer.parse_graph() + sq_weight_tensors, sq_weight_node_names = self._parse_weight_dict(max_vals_per_channel, sq_weight_tensor_dict) + logger.info("Start scaling on model graph for Smooth Quantization.") + if self.scales_per_op: + # 1. obtain the smooth scale per op + # 2. adjust weight + # 3. adjust activation + for _, input_node_name in enumerate(max_vals_per_channel): + activation_max_per_in_channel = max_vals_per_channel[input_node_name] + W_lst = sq_weight_tensors[input_node_name] # VQK weight value + # Use the const nodes before to get weight values, VQK ReadVariable + W_node_name_lst = sq_weight_node_names[input_node_name] + # Get the concrete weight node as the output of Mul insertion, QKV ReadVariable + for w_i, W in enumerate(W_lst): + if len(W.shape) == 4: + # https://www.tensorflow.org/api_docs/python/tf/nn/conv2d + # weight: [filter_height, filter_width, in_channels, out_channels] + # activation: NHWC, also batch_shape + [in_height, in_width, in_channels] + tensor = np.abs(np.transpose(W, [0, 1, 3, 2])) + # reduce weight max to (in_channel, ), aligned with activation max + W_max_per_in_channel = np.max(np.reshape(tensor, (-1, tensor.shape[-1])), axis=0) + elif len(W.shape) == 2: # matmul + # reduce weight max to (in_channel, ), aligned with activation max + tensor = np.abs(W) + W_max_per_in_channel = np.max(tensor, axis=1) + else: # pragma: no cover + assert False, "not supported" + cur_weight_node_name = W_node_name_lst[w_i] + try: + scale = np.power(activation_max_per_in_channel, self.alpha) / np.power( + W_max_per_in_channel, (1 - self.alpha) + ) + except ValueError as e: # pragma: no cover + logger.info(e) + logger.info("Skip smoothing the node: {}".format(cur_weight_node_name)) + continue + # clip the scales that are too small + scale = np.clip(scale, a_min=1e-5, a_max=1e8) + # skip smoothing the op where scale has elements that less than 1 + # if np.any(scale < 1): + # logger.info("skip smooth quant: {}".format(input_node_name)) + # continue + self.sq_weight_scale_dict[cur_weight_node_name] = scale + self._adjust_activation(1 / scale, input_node_name, sq_target_node_names[cur_weight_node_name], w_i) + else: + pass + sq_graph_def = self.g_analyzer.dump_graph() + sq_graph_def.library.CopyFrom(self.graph_def.library) + return sq_graph_def, self.sq_weight_scale_dict, self.mul_list diff --git a/neural_compressor/adaptor/tf_utils/util.py b/neural_compressor/adaptor/tf_utils/util.py index a2de64ad3bf..eba40281d5c 100644 --- a/neural_compressor/adaptor/tf_utils/util.py +++ b/neural_compressor/adaptor/tf_utils/util.py @@ -24,8 +24,15 @@ import tensorflow as tf from google.protobuf import text_format from pkg_resources import parse_version -from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2 +from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2, variable_pb2 +from tensorflow.core.protobuf import config_pb2, meta_graph_pb2 +from tensorflow.python.eager import context, wrap_function +from tensorflow.python.framework import convert_to_constants +from tensorflow.python.grappler import tf_optimizer from tensorflow.python.platform import gfile +from tensorflow.python.saved_model import load, save, signature_constants, tag_constants +from tensorflow.python.training import saver +from tensorflow.python.util import nest from neural_compressor.utils import logger @@ -511,8 +518,6 @@ def int8_node_name_reverse(node): def 
tf_diagnosis_helper(fp32_model, quan_model, tune_cfg, save_path): """Tensorflow diagnosis helper function.""" - import tensorflow as tf - from ...utils.utility import dump_data_to_local fp32_node_mapping = {} @@ -659,3 +664,160 @@ def get_weight_from_input_tensor(model, input_tensor_names, op_types): sq_weight_tensors[name] = curr_weight_tensors sq_weights_nodes[name] = curr_weights_nodes return sq_weight_tensors, sq_weights_nodes + + +def apply_inlining(func): + """Apply an inlining optimization to the function's graph definition. + + Args: + func: A concrete function get from saved_model. + + Returns: + new_graph_def: The optimized graph in graph_def format. + """ + graph_def = func.graph.as_graph_def() + + # In some cases, a secondary implementation of the function (e.g. for GPU) is + # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in + # TF2 produces a CuDNN-based RNN for GPU). + # This function suppose to inline all functions calls, but "api_implements" + # prevents this from happening. Removing the attribute solves the problem. + # To learn more about "api_implements", see: + # tensorflow/core/grappler/optimizers/implementation_selector.h + for function in graph_def.library.function: + if "api_implements" in function.attr: + del function.attr["api_implements"] + + meta_graph = saver.export_meta_graph(graph_def=graph_def, graph=func.graph) + + # Clear the initializer_name for the variables collections, since they are not + # needed after saved to saved_model. + for name in ["variables", "model_variables", "trainable_variables", "local_variables"]: + raw_list = [] + for raw in meta_graph.collection_def["variables"].bytes_list.value: + variable = variable_pb2.VariableDef() + variable.ParseFromString(raw) + variable.ClearField("initializer_name") + raw_list.append(variable.SerializeToString()) + meta_graph.collection_def[name].bytes_list.value[:] = raw_list + + # Add a collection 'train_op' so that Grappler knows the outputs. + fetch_collection = meta_graph_pb2.CollectionDef() + for array in func.inputs + func.outputs: + fetch_collection.node_list.value.append(array.name) + meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) + + # Initialize RewriterConfig with everything disabled except function inlining. + config = config_pb2.ConfigProto() + rewrite_options = config.graph_options.rewrite_options + rewrite_options.min_graph_nodes = -1 # do not skip small graphs + rewrite_options.optimizers.append("function") + + new_graph_def = tf_optimizer.OptimizeGraph(config, meta_graph) + + return new_graph_def + + +def construct_function_from_graph_def(func, graph_def, frozen_func=None): + """Rebuild function from graph_def. + + Args: + func: The original concrete function get from saved_model. + graph_def: The optimized graph after applying inlining optimization. + + Returns: + new_func: The reconstructed function. + """ + if frozen_func is None: + frozen_func = func + + # If a function is converted, then the TF context contains the original + # function while the converted_graph_def contains the converted function. + # Remove the original function from the TF context in this case. 
+ for f in graph_def.library.function: + while context.context().has_function(f.signature.name): + context.context().remove_function(f.signature.name) + + captures = {c[1].name.split(":")[0]: c[0] for c in frozen_func.graph.captures} + new_func = wrap_function.function_from_graph_def( + graph_def, + [tensor.name for tensor in frozen_func.inputs], + [tensor.name for tensor in frozen_func.outputs], + captures, + ) + new_func.graph.structured_outputs = nest.pack_sequence_as( + func.graph.structured_outputs, new_func.graph.structured_outputs + ) + # new_func._function_type = func.function_type # pylint: disable=protected-access + + # Copy structured input signature from original function (used during + # serialization) + new_func.graph.structured_input_signature = func.structured_input_signature + + return new_func + + +def parse_saved_model(model, freeze=False, input_tensor_names=[], output_tensor_names=[]): + """Parse a input saved_model. + + Args: + model(string or AutoTrackable object): The input saved_model. + + Returns: + graph_def: The graph_def parsed from saved_model. + _saved_model: TF AutoTrackable object loaded from saved_model. + func: The concrete function get from saved_model. + frozen_func: The reconstructed function from inlining optimized graph. + """ + config = tf.compat.v1.ConfigProto() + config.use_per_session_threads = 1 + config.inter_op_parallelism_threads = 1 + + if isinstance(model, str): + _saved_model = load.load(model, [tag_constants.SERVING]) + else: + _saved_model = model + + func = _saved_model.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + + if freeze: + frozen_func = convert_to_constants.convert_variables_to_constants_v2(func) + else: + inlined_graph_def = apply_inlining(func) + frozen_func = construct_function_from_graph_def(func, inlined_graph_def) + + if len(input_tensor_names) == 0: + # skip all inputs for ReadVariableOp + input_tensor_names = [i.name.split(":")[0] for i in frozen_func.inputs if "unknown" not in i.name] + if len(output_tensor_names) == 0: + output_tensor_names = [i.name.split(":")[0] for i in frozen_func.outputs] + + frozen_graph_def = frozen_func.graph.as_graph_def() + grappler_meta_graph_def = saver.export_meta_graph(graph_def=frozen_graph_def, graph=frozen_func.graph) + + # Add a collection 'train_op' so that Grappler knows the outputs. + fetch_collection = meta_graph_pb2.CollectionDef() + for array in frozen_func.inputs + frozen_func.outputs: + fetch_collection.node_list.value.append(array.name) + grappler_meta_graph_def.collection_def["train_op"].CopyFrom(fetch_collection) + + grappler_session_config = config_pb2.ConfigProto() + rewrite_options = grappler_session_config.graph_options.rewrite_options + rewrite_options.min_graph_nodes = -1 + graph_def = tf_optimizer.OptimizeGraph(grappler_session_config, grappler_meta_graph_def, graph_id=b"tf_graph") + return graph_def, _saved_model, func, frozen_func, input_tensor_names, output_tensor_names + + +def reconstruct_saved_model(graph_def, func, frozen_func, trackable, path): + """Reconstruct a saved_model. + + Args: + graph_def: The input graph_def. + func: The concrete function get from the original saved_model. + frozen_func: The reconstructed function from inlining optimized graph. + trackable: TF AutoTrackable object loaded from the original saved_model. + path: The destination path to save the reconstructed saved_model. 
+ """ + converted_func = construct_function_from_graph_def(func, graph_def, frozen_func) + signatures = {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: converted_func} + save.save(trackable, path, signatures, options=None) diff --git a/neural_compressor/model/model.py b/neural_compressor/model/model.py index 5480d29906b..2c9c6358d97 100644 --- a/neural_compressor/model/model.py +++ b/neural_compressor/model/model.py @@ -28,6 +28,7 @@ from neural_compressor.model.onnx_model import ONNXModel from neural_compressor.model.tensorflow_model import ( TensorflowBaseModel, + TensorflowLLMModel, TensorflowModel, TensorflowQATModel, get_model_type, @@ -52,8 +53,8 @@ MODELS = { "tensorflow": TensorflowModel, "tensorflow_itex": TensorflowModel, - "keras": KerasModel, "tensorflow_qat": TensorflowQATModel, + "keras": KerasModel, "mxnet": MXNetModel, "pytorch": PyTorchModel if TORCH else None, "pytorch_ipex": IPEXModel if TORCH else None, @@ -178,7 +179,10 @@ def __new__(cls, root, **kwargs): conf = kwargs.pop("conf", "NA") if isinstance(root, BaseModel): if conf != "NA" and conf.framework is None: - conf.framework = list(MODELS.keys())[list(MODELS.values()).index(type(root))] + try: + conf.framework = list(MODELS.keys())[list(MODELS.values()).index(type(root))] + except: + conf.framework = get_model_fwk_name(root._model) if hasattr(conf, "backend") and conf.backend == "ipex": assert conf.framework == "pytorch_ipex", "Please wrap the model with correct Model class!" if hasattr(conf, "backend") and conf.backend == "itex": @@ -235,6 +239,8 @@ def __new__(cls, root, **kwargs): model_type = kwargs["modelType"] else: model_type = get_model_type(root) + if model_type == "llm_saved_model": + return TensorflowLLMModel(root, **kwargs) if hasattr(conf, "backend") and conf.backend == "itex": if model_type == "keras": conf.framework = "keras" diff --git a/neural_compressor/model/tensorflow_model.py b/neural_compressor/model/tensorflow_model.py index 12253f4e78d..96b107c4683 100644 --- a/neural_compressor/model/tensorflow_model.py +++ b/neural_compressor/model/tensorflow_model.py @@ -23,6 +23,7 @@ import shutil import sys import tempfile +import time from abc import abstractmethod from neural_compressor import config as cfg @@ -243,13 +244,15 @@ def _contains_function_with_implements_attr(saved_model_proto): return False -def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names): +def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names): # pragma: no cover """Load graph_def from saved model with the default serving signature key. Args: - saved_model_dir: Directory of the SavedModel. + model: Directory of the SavedModel. saved_model_tags: Set of tags identifying the MetaGraphDef within the SavedModel to analyze. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): output_tensor_names of model. Returns: graph_def: The loaded GraphDef. 
@@ -310,9 +313,17 @@ def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_ def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names): from tensorflow.python.saved_model import signature_constants, tag_constants + from neural_compressor.adaptor.tf_utils.util import parse_saved_model + saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] saved_model_tags = set([tag_constants.SERVING]) - return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names) + try: + graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model( + saved_model_dir, True, input_tensor_names, output_tensor_names + ) + except: + return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names) + return graph_def, input_names, output_names def _get_graph_from_original_keras_v2(model, output_dir): @@ -452,7 +463,7 @@ def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model, temp_dir) except: keras_format = "saved_model_v1" - if keras_format == "saved_model_v1": + if keras_format == "saved_model_v1": # pragma: no cover try: tf.keras.backend.set_learning_phase(0) graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) @@ -647,6 +658,7 @@ def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs "graph_def": graph_def_session, "graph": graph_session, "saved_model": saved_model_session, + "llm_saved_model": saved_model_session, "keras": keras_session, "checkpoint": checkpoint_session, "estimator": estimator_session, @@ -933,6 +945,15 @@ def export(self, save_path, conf): class TensorflowSavedModelModel(TensorflowBaseModel): """Build Tensorflow saved model.""" + def __init__(self, model, **kwargs): + """Initialize a Tensorflow model. + + Args: + model (string or tensorflow model object): model path or model object. + """ + super(TensorflowSavedModelModel, self).__init__(model, **kwargs) + self._auto_trackable = None + def get_all_weight_names(self): """Get weight names of model. @@ -961,9 +982,9 @@ def get_weight(self, tensor_name): @property def model(self): - """Return model itself.""" - import shutil - import time + """Return model in AutoTrackable object.""" + if self._auto_trackable: + return self._auto_trackable root = os.path.abspath(os.path.expanduser(cfg.default_workspace)) root += str(time.time()) @@ -976,8 +997,14 @@ def model(self): builder.save() model = tf.saved_model.load(root) shutil.rmtree(root) + self._auto_trackable = model return model + @model.setter + def model(self, input_model): + """Set model in AutoTrackable object.""" + self._auto_trackable = input_model + def report_sparsity(self): """Get sparsity of the model. @@ -1074,6 +1101,189 @@ def save(self, root=None): logger.info("Save quantized model to {}.".format(root)) +class TensorflowLLMModel(TensorflowSavedModelModel): + """The class Tensorflow saved model whose GraphDef exceeding maximum protobuf size of 2GB.""" + + def __init__(self, model, **kwargs): + """Initialize a Tensorflow model. + + Args: + model (string or tensorflow model object): model path or model object. 
+ """ + super(TensorflowLLMModel, self).__init__(model, **kwargs) + + self._model_path = self.kwargs.get("model_path", None) + self._weight_name_mapping = self.kwargs.get("weight_name_mapping", None) + self._sq_weight_scale_dict = self.kwargs.get("sq_weight_scale_dict", None) + self._weight_tensor_minmax_dict = {} + self._model_type = "llm_saved_model" + + from neural_compressor.adaptor.tf_utils.util import parse_saved_model + + ( + self._graph_def, + self._saved_model, + self.func, + self.frozen_func, + self._input_tensor_names, + self._output_tensor_names, + ) = parse_saved_model(model) + + @property + def model_path(self): + """Return model path. + + The model path in this class is used as a temp path for intermediate model + """ + return self._model_path + + @model_path.setter + def model_path(self, path): + """Set model path. + + The model path in this class is used as a temp path for intermediate model + """ + self.kwargs.update({"model_path": path}) + self._model_path = path + + @property + def graph_def(self): + """Return graph_def.""" + return self._graph_def + + @graph_def.setter + def graph_def(self, graph_def): + """Set graph definition.""" + self._graph_def = graph_def + # the attributes of some nodes can't be correctly read if don't import the graph_def + tf.import_graph_def(self._graph_def, name="") + + @property + def model(self): + """Return model in AutoTrackable Format.""" + if self._sq_weight_scale_dict: + self.adjust_weight(self.graph_def) + if not self._auto_trackable: + self._auto_trackable = tf.saved_model.load(self._model) + return self._auto_trackable + + @property + def weight_name_mapping(self): + """Return weight_name_mapping function.""" + if not self._weight_name_mapping: + self._weight_name_mapping = self.kwargs.get("weight_name_mapping", None) + assert self._weight_name_mapping is not None, "weight_name_mapping should not be None!" + return self._weight_name_mapping + + @weight_name_mapping.setter + def weight_name_mapping(self, weight_name_mapping): + """Set weight_name_mapping function.""" + self.kwargs.update({"weight_name_mapping": weight_name_mapping}) + self._weight_name_mapping = weight_name_mapping + + @property + def sq_weight_scale_dict(self): + """Return dict of weight scaler for smooth quantization.""" + if not self._sq_weight_scale_dict: + self._sq_weight_scale_dict = self.kwargs.get("sq_weight_scale_dict", None) + assert self._weight_name_mapping is not None, "sq_weight_scale_dict should not be None!" 
+ return self._sq_weight_scale_dict + + @sq_weight_scale_dict.setter + def sq_weight_scale_dict(self, sq_weight_scale_dict): + """Set dict of weight scaler for smooth quantization.""" + self.kwargs.update({"sq_weight_scale_dict": sq_weight_scale_dict}) + self._sq_weight_scale_dict = sq_weight_scale_dict + + @property + def weight_tensor_minmax_dict(self): + """Return dict of weight scaler for smooth quantization.""" + return self._weight_tensor_minmax_dict + + @property + def input_tensor_names(self): + """Return input tensor names.""" + return copy.deepcopy(self._input_tensor_names) + + @input_tensor_names.setter + def input_tensor_names(self, tensor_names): + """Set input tensor names.""" + if len(tensor_names) == 0: # pragma: no cover + logger.warn("Input tensor names is empty.") + return + + assert validate_graph_node( + self._graph_def, tensor_to_node(tensor_names) + ), "tensor names {} not in graph".format(tensor_names) + self._input_tensor_names = tensor_names + + @property + def output_tensor_names(self): + """Return output tensor names.""" + return copy.deepcopy(self._output_tensor_names) + + @output_tensor_names.setter + def output_tensor_names(self, tensor_names): + """Set output tensor names.""" + if len(tensor_names) == 0: # pragma: no cover + logger.warn("Output tensor names is empty.") + return + if self._graph_def is not None: + assert validate_graph_node( + self.graph_def, tensor_to_node(tensor_names) + ), "tensor names {} not in graph".format(tensor_names) + self._output_tensor_names = tensor_names + + @property + def output_node_names(self): + """Return output node names.""" + output_node_names = tensor_to_node(self.output_tensor_names) + return copy.deepcopy(output_node_names) + + def adjust_weight(self, graph_def): + """Adjust weight of LLM saved_model by scale.""" + from tensorflow.python.saved_model import load, tag_constants + + from neural_compressor.adaptor.tf_utils.util import reconstruct_saved_model + + reconstruct_saved_model(graph_def, self.func, self.frozen_func, self._saved_model, self.model_path) + model = load.load(self.model_path, [tag_constants.SERVING]) + + for idx, weight_tensor in enumerate(model.variables): + parsed_weight_name = self.weight_name_mapping(weight_tensor.name) + if parsed_weight_name in self.sq_weight_scale_dict: + weight_array = np.transpose(weight_tensor, [1, 0]) + weight_array *= self.sq_weight_scale_dict[parsed_weight_name] + weight_array = np.transpose(weight_array, [1, 0]) + tf.compat.v1.assign(model.variables[idx], weight_array) + else: + weight_array = weight_tensor + + if parsed_weight_name not in self._weight_tensor_minmax_dict: + self._weight_tensor_minmax_dict[parsed_weight_name] = [np.min(weight_array), np.max(weight_array)] + self._auto_trackable = model + + def save(self, root=None): + """Save the model to the root path.""" + import shutil + + from neural_compressor.adaptor.tf_utils.util import parse_saved_model, reconstruct_saved_model + + if not root: + root = cfg.default_workspace + root = os.path.abspath(os.path.expanduser(root)) + if os.path.exists(root): + shutil.rmtree(root) + os.makedirs(root, exist_ok=True) + + self.adjust_weight(self._graph_def) + graph_def, _saved_model, func, frozen_func, _, _ = parse_saved_model(self._auto_trackable) + reconstruct_saved_model(graph_def, func, frozen_func, _saved_model, root) + logger.info("Save quantized model to {}.".format(root)) + # delete the LLM file saved in this temporary path + shutil.rmtree(self.model_path, ignore_errors=True) + + class 
TensorflowQATModel(TensorflowSavedModelModel): """Build Tensorflow QAT model.""" @@ -1181,6 +1391,7 @@ def model(self): "slim": TensorflowBaseModel, "saved_model": TensorflowSavedModelModel, "keras": TensorflowSavedModelModel, + "llm_saved_model": TensorflowLLMModel, } diff --git a/test/tfnewapi/test_big_saved_model.py b/test/tfnewapi/test_big_saved_model.py new file mode 100644 index 00000000000..eac046079fd --- /dev/null +++ b/test/tfnewapi/test_big_saved_model.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import time +import unittest + +import numpy as np +import tensorflow as tf +from tensorflow import keras + +from neural_compressor.data.dataloaders.dataloader import DataLoader + + +def build_model(): + # Load MNIST dataset + mnist = keras.datasets.mnist + # 60000 images in train and 10000 images in test, but we don't need so much for ut + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + train_images, train_labels = train_images[:1000], train_labels[:1000] + test_images, test_labels = test_images[:200], test_labels[:200] + # Normalize the input image so that each pixel value is between 0 to 1. + train_images = train_images / 255.0 + test_images = test_images / 255.0 + + # Define the model architecture. + model = keras.Sequential( + [ + keras.layers.InputLayer(input_shape=(28, 28)), + keras.layers.Reshape(target_shape=(28, 28, 1)), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu"), + keras.layers.MaxPooling2D(pool_size=(2, 2)), + keras.layers.Flatten(), + keras.layers.Dense(10), + ] + ) + # Train the digit classification model + model.compile( + optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] + ) + + model.fit( + train_images, + train_labels, + epochs=1, + validation_split=0.1, + ) + + _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) + + print("Baseline test accuracy:", baseline_model_accuracy) + model.save("baseline_model") + + +class Dataset(object): + def __init__(self, batch_size=100): + mnist = keras.datasets.mnist + # 60000 images in train and 10000 images in test, but we don't need so much for ut + (train_images, train_labels), (test_images, test_labels) = mnist.load_data() + train_images, train_labels = train_images[:1000], train_labels[:1000] + test_images, test_labels = test_images[:200], test_labels[:200] + # Normalize the input image so that each pixel value is between 0 to 1. 
+ self.train_images = train_images / 255.0 + self.test_images = test_images / 255.0 + self.train_labels = train_labels + self.test_labels = test_labels + + def __len__(self): + return len(self.test_images) + + def __getitem__(self, idx): + return self.test_images[idx].astype(np.float32), self.test_labels[idx] + + +class TestBigSavedModel(unittest.TestCase): + @classmethod + def setUpClass(self): + build_model() + + @classmethod + def tearDownClass(self): + shutil.rmtree("baseline_model", ignore_errors=True) + shutil.rmtree("int8_model", ignore_errors=True) + + def test_newapi_sq_big_saved_model(self): + def weight_name_mapping(name): + """The function that maps name from AutoTrackable variables to graph nodes.""" + name = name.replace("dense", "StatefulPartitionedCall/sequential/dense/MatMul") + name = name.replace("conv2d", "StatefulPartitionedCall/sequential/conv2d/Conv2D") + name = name.replace("kernel:0", "ReadVariableOp") + return name + + from neural_compressor import Model, quantization + from neural_compressor.config import PostTrainingQuantConfig + + model = Model("baseline_model", modelType="llm_saved_model") + model.weight_name_mapping = weight_name_mapping + + output_node_names = model.output_node_names + self.assertEqual(output_node_names, ["Identity"]) + + calib_dataloader = DataLoader(framework="tensorflow", dataset=Dataset()) + recipes = {"smooth_quant": True, "smooth_quant_args": {"alpha": 0.6}} + op_name_dict = { + "StatefulPartitionedCall/sequential/conv2d/Conv2D": { + "weight": {"dtype": ["fp32"]}, + "activation": {"dtype": ["fp32"]}, + } + } + config = PostTrainingQuantConfig( + quant_level=1, + recipes=recipes, + op_name_dict=op_name_dict, + calibration_sampling_size=[500], + ) + model.weight_name_mapping = weight_name_mapping + q_model = quantization.fit(model, config, calib_dataloader=calib_dataloader) + q_model.save("int8_model") + quant_count = 0 + for i in q_model.graph_def.node: + if i.op == "QuantizeV2": + quant_count += 1 + + self.assertEqual(quant_count, 3) + + +if __name__ == "__main__": + unittest.main()
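A minimal usage sketch of the parse_saved_model/reconstruct_saved_model helpers added to neural_compressor/adaptor/tf_utils/util.py in this patch, assuming a SavedModel at the placeholder path ./saved_model_dir that exposes a serving_default signature (the output path ./rewritten_model is likewise a placeholder):

    # Parse a SavedModel into an inlined GraphDef without folding variables
    # into constants (freeze=False is the default), so a model whose weights
    # exceed the 2GB protobuf limit keeps its variables as ReadVariableOp nodes.
    from neural_compressor.adaptor.tf_utils.util import (
        parse_saved_model,
        reconstruct_saved_model,
    )

    graph_def, trackable, func, frozen_func, input_names, output_names = parse_saved_model(
        "./saved_model_dir"
    )
    print("inputs:", input_names, "outputs:", output_names)

    # graph_def can be rewritten at this point; for example, the smooth quant
    # pass in this patch inserts Mul nodes and records per-channel scales
    # before writing the model back out.

    reconstruct_saved_model(graph_def, func, frozen_func, trackable, "./rewritten_model")

Keeping the variables out of the GraphDef (instead of calling convert_variables_to_constants_v2) is what allows TensorflowLLMModel and SmoothQuantCalibrationLLM to round-trip big saved models through a temporary directory rather than a single oversized protobuf.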