Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docstring for PT2E and HQQ #1937

Merged
merged 7 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
/neural-compressor/neural_compressor/strategy
/neural-compressor/neural_compressor/training.py
/neural-compressor/neural_compressor/utils
/neural_compressor/torch/algorithms/pt2e_quant
/neural_compressor/torch/export
/neural_compressor/common
/neural_compressor/torch/algorithms/weight_only/hqq
1 change: 1 addition & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The PT2E-related modules."""


from neural_compressor.torch.algorithms.pt2e_quant.core import W8A8PT2EQuantizer
Expand Down
39 changes: 36 additions & 3 deletions neural_compressor/torch/algorithms/pt2e_quant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# Some code snippets are taken from the X86InductorQuantizer tutorial.
# https://pytorch.org/tutorials/prototype/pt2e_quant_x86_inductor.html

"""The quantizer using PT2E path."""

from typing import Any

Expand All @@ -30,13 +30,24 @@


class W8A8PT2EQuantizer(Quantizer):
"""The W8A8 quantizer using PT2E."""

is_dynamic = False

def __init__(self, quant_config=None):
"""Initialize the quantizer."""
super().__init__(quant_config)

@staticmethod
def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuantizer:
"""Updates the quantizer based on the given quantization configuration.

Args:
quant_config (dict): The quantization configuration. Defaults to None.

Returns:
X86InductorQuantizer: The updated quantizer object.
"""
if not quant_config:
quantizer = X86InductorQuantizer()
quantizer.set_global(
Expand All @@ -47,9 +58,18 @@ def update_quantizer_based_on_quant_config(quant_config=None) -> X86InductorQuan
return quantizer

def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args, **kwargs) -> GraphModule:
"""Prepare the model for calibration.
"""Prepares the model for calibration.

Create the `quantizer` according to the `quant_config`, and insert the observers accordingly.

Args:
model (GraphModule): The model to be prepared for calibration.
example_inputs (tuple, optional): Example inputs to be used for calibration. Defaults to None.
inplace (bool, optional): Whether to modify the model in-place or return a new prepared model.
Defaults to True.

Returns:
GraphModule: The prepared model.
"""
quant_config = self.quant_config
assert model._exported, "The model should be exported before preparing it for calibration."
Expand All @@ -58,7 +78,14 @@ def prepare(self, model: GraphModule, example_inputs=None, inplace=True, *args,
return prepared_model

def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
"""Convert the calibrated model into qdq mode."""
"""Convert the calibrated model into qdq mode.

Args:
model (GraphModule): The prepared model.

Returns:
GraphModule: The converted quantized model.
"""
fold_quantize = kwargs.get("fold_quantize", False)
converted_model = convert_pt2e(model, fold_quantize=fold_quantize)
logger.warning("Converted the model in qdq mode, please compile it to accelerate inference.")
Expand All @@ -67,6 +94,12 @@ def convert(self, model: GraphModule, *args: Any, **kwargs: Any) -> GraphModule:
return converted_model

def half_precision_transformation(self, model, config):
"""Applies half-precision transformation to the given model in-place.

Args:
model: The model to apply the transformation to.
config: The configuration for the transformation.
"""
half_precision_node_set = hp_rewriter.get_half_precision_node_set(model, config)
logger.info("Try to convert %d nodes to half precision.", len(half_precision_node_set))
hp_rewriter.transformation(model, half_precision_node_set)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rewrite the FP32 operators to FP16 or BF16 operators."""

from dataclasses import dataclass
from functools import partial
Expand All @@ -34,6 +35,14 @@

@dataclass
class PatternPair:
"""Represents a pair of patterns used for search and replacement in a graph.

Attributes:
fn (TorchFuncType): The function type associated with the pattern pair.
search_pattern (torch.fx.GraphModule): The search pattern to be matched in the graph.
replace_pattern (torch.fx.GraphModule): The replacement pattern to be used when a match is found.
"""

fn: TorchFuncType
search_pattern: torch.fx.GraphModule
replace_pattern: torch.fx.GraphModule
Expand Down Expand Up @@ -101,6 +110,15 @@ def _register_pattern_pair(dtype: torch.dtype) -> None:


def get_filter_fn(node_list, fn):
"""Filter function to check if a node with the target operator is in the given `node_list`.

Args:
node_list (list): List of nodes to check against.
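fn (TorchFuncType): The torch function whose ATen operator is looked up in `FN_ATEN_OPS_MAPPING`.

Returns:
Callable: A match filter taking `(match, original_graph, pattern_graph)` that checks whether the node with the target operator is in `node_list`.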
"""
target_op = FN_ATEN_OPS_MAPPING[fn]

def is_target_node_in_candidate_list(match, original_graph, pattern_graph):
Expand All @@ -119,6 +137,16 @@ def is_target_node_in_candidate_list(match, original_graph, pattern_graph):


def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPair, node_list):
"""Applies a single pattern pair to a given GraphModule.

Args:
gm (torch.fx.GraphModule): The GraphModule to apply the pattern pair to.
pattern_pair (PatternPair): The pattern pair containing the search and replace patterns.
node_list: The list of nodes to filter for pattern matching.

Returns:
List[Match]: A list of Match objects representing the matches found after applying the pattern pair.
"""
filter_fn = get_filter_fn(node_list, pattern_pair.fn)
match_and_replacements = subgraph_rewriter.replace_pattern_with_filters(
gm=gm,
Expand All @@ -133,6 +161,14 @@ def apply_single_pattern_pair(gm: torch.fx.GraphModule, pattern_pair: PatternPai


def get_unquantized_node_set(gm: torch.fx.GraphModule):
"""Retrieves the set of unquantized nodes from a given GraphModule.

Args:
gm (torch.fx.GraphModule): The GraphModule to retrieve unquantized nodes from.

Returns:
set: A set containing the unquantized nodes.
"""
unquantized_node_set = set()
for node in gm.graph.nodes:
if meta := getattr(node, "meta"):
Expand Down Expand Up @@ -180,7 +216,17 @@ def _parse_node_candidate_set_from_user_config(config, gm):


def get_half_precision_node_set(gm, config):
"""Intersection between `unquantized_node_set` and `node_set_from_user_config`"""
"""Retrieves a set of nodes from the given graph model (gm) that are candidates for conversion to half precision.

The result is the intersection between `unquantized_node_set` and `node_set_from_user_config`.

Args:
gm (torch.fx.GraphModule): The graph module to search for nodes.
config (dict): User configuration for node candidate set.

Returns:
set: A set of nodes that are candidates for conversion to half precision.
"""
# TODO: implement it; currently this returns all of unquantized_node_set

node_set_from_user_config = _parse_node_candidate_set_from_user_config(config, gm)
Expand Down
17 changes: 17 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Save and load the quantized model."""


import json
import os
Expand All @@ -22,6 +24,13 @@


def save(model, example_inputs, output_dir="./saved_results"):
"""Save the quantized model and its configuration.

Args:
model (torch.nn.Module): The quantized model to be saved.
example_inputs (torch.Tensor or tuple of torch.Tensor): Example inputs used for tracing the model.
output_dir (str, optional): The directory where the saved results will be stored. Defaults to "./saved_results".
"""
os.makedirs(output_dir, exist_ok=True)
qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
qconfig_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), QCONFIG_NAME)
Expand All @@ -37,6 +46,14 @@ def save(model, example_inputs, output_dir="./saved_results"):


def load(output_dir="./saved_results"):
"""Load a quantized model from the specified output directory.

Args:
output_dir (str, optional): The directory where the quantized model is saved. Defaults to "./saved_results".

Returns:
torch.nn.Module: The loaded quantized model.
"""
qmodel_file_path = os.path.join(os.path.abspath(os.path.expanduser(output_dir)), WEIGHT_NAME)
loaded_quantized_ep = torch.export.load(qmodel_file_path)
return loaded_quantized_ep.module()
22 changes: 22 additions & 0 deletions neural_compressor/torch/algorithms/pt2e_quant/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for PT2E quantization."""

from typing import Dict

Expand All @@ -24,6 +25,18 @@


def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
"""Create a quantization specification based on the given configuration.

Args:
dtype (str): The desired data type for quantization. Valid options are "int8" and "uint8".
sym (bool): Whether to use symmetric quantization or not.
granularity (str): The granularity of quantization. Valid options are "per_channel" and "per_tensor".
algo (str): The algorithm to use for quantization. Valid options are "placeholder", "minmax", and "kl".
is_dynamic (bool, optional): Whether to use dynamic quantization or not. Defaults to False.

Returns:
QuantizationSpec: The created quantization specification.
"""
dtype_mapping: Dict[str, torch.dtype] = {"int8": torch.int8, "uint8": torch.uint8}
select_dtype = dtype_mapping[dtype]
min_max_mapping = {torch.int8: (-128, 127), torch.uint8: (0, 255)}
Expand Down Expand Up @@ -76,6 +89,15 @@ def _map_inc_config_to_torch_quant_config(inc_config, is_dynamic=False) -> Quant


def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86InductorQuantizer:
"""Creates an instance of X86InductorQuantizer based on the given configuration.

Args:
config: The configuration object containing the quantization settings.
is_dynamic: A boolean indicating whether dynamic quantization is enabled.

Returns:
An instance of X86InductorQuantizer initialized with the provided configuration.
"""
quantizer = xiq.X86InductorQuantizer()
# set global
global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HQQ-related modules."""

from .quantizer import HQQuantizer
from .config import HQQModuleConfig, QTensorConfig
Loading
Loading