forked from openvinotoolkit/nncf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement FracBitsQuantizationBuilder and Controller (openvinotoolkit#1234)
* Implement FracBitsQuantizationBuilder and Controller
- Implement Builder and Controller
- Add and test ModelSizeCompressionLoss
Signed-off-by: Kim, Vinnam <[email protected]>
- Loading branch information
Showing
12 changed files
with
859 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
""" | ||
Copyright (c) 2022 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from nncf.experimental.torch.fracbits.controller import FracBitsQuantizationController | ||
from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS | ||
from nncf.torch.compression_method_api import PTCompressionAlgorithmController | ||
from nncf.torch.nncf_network import NNCFNetwork | ||
from nncf.torch.quantization.algo import QuantizationBuilder | ||
from nncf.torch.quantization.layers import PTQuantizerSetup | ||
from nncf.common.quantization.structs import QuantizationMode | ||
from nncf.experimental.torch.fracbits.quantizer import FracBitsQuantizationMode | ||
|
||
|
||
@PT_COMPRESSION_ALGORITHMS.register('fracbits_quantization')
class FracBitsQuantizationBuilder(QuantizationBuilder):
    """
    Quantization builder registered under the 'fracbits_quantization' algorithm name.

    It reuses the base quantizer setup, replaces the quantization mode of every
    quantization point with its FracBits counterpart and builds a
    `FracBitsQuantizationController`.
    """

    def _get_quantizer_setup(self, target_model: NNCFNetwork) -> PTQuantizerSetup:
        """
        Builds the base quantizer setup and switches each quantization point to a FracBits mode.

        :param target_model: Model the quantizer setup is built for.
        :return: The setup produced by the base builder, with `qspec.mode` of every
            quantization point rewritten in place.
        :raises ValueError: If a quantization point has a mode that is neither
            `QuantizationMode.ASYMMETRIC` nor `QuantizationMode.SYMMETRIC`.
        """
        setup = super()._get_quantizer_setup(target_model)

        for q_point in setup.quantization_points.values():
            mode = q_point.qspec.mode
            if mode == QuantizationMode.ASYMMETRIC:
                q_point.qspec.mode = FracBitsQuantizationMode.ASYMMETRIC
            elif mode == QuantizationMode.SYMMETRIC:
                q_point.qspec.mode = FracBitsQuantizationMode.SYMMETRIC
            else:
                raise ValueError(f"qspec.mode={mode} is unknown.")

        return setup

    def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
        """
        Creates the FracBits controller from the state collected by this builder.

        :param model: Model handed over to the controller.
        :return: A `FracBitsQuantizationController` instance.
        """
        return FracBitsQuantizationController(model,
                                              self.config,
                                              self._debug_interface,
                                              self._weight_quantizers,
                                              self._non_weight_quantizers,
                                              self._groups_of_adjacent_quantizers,
                                              self._quantizers_input_shapes,
                                              build_time_metric_info=self._build_time_metric_infos,
                                              build_time_range_init_params=self._range_init_params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
""" | ||
Copyright (c) 2022 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from contextlib import contextmanager | ||
from typing import Dict, Tuple | ||
|
||
from nncf.common.quantization.structs import NonWeightQuantizerId, QuantizerId, WeightQuantizerId | ||
from nncf.common.statistics import NNCFStatistics | ||
from nncf.config.config import NNCFConfig | ||
from nncf.config.extractors import extract_algo_specific_config | ||
from nncf.experimental.torch.fracbits.statistics import FracBitsStatistics | ||
from nncf.experimental.torch.fracbits.scheduler import FracBitsQuantizationScheduler | ||
from nncf.torch.compression_method_api import PTCompressionLoss | ||
from nncf.torch.nncf_network import NNCFNetwork | ||
from nncf.torch.quantization.algo import QuantizationController, QuantizationDebugInterface | ||
from nncf.torch.quantization.init_range import PTRangeInitParams | ||
from nncf.torch.quantization.metrics import QuantizationShareBuildTimeInfo | ||
from nncf.torch.quantization.precision_init.adjacent_quantizers import GroupsOfAdjacentQuantizers | ||
from nncf.torch.quantization.structs import NonWeightQuantizerInfo, WeightQuantizerInfo | ||
from nncf.experimental.torch.fracbits.loss import FRACBITS_LOSSES | ||
|
||
|
||
class FracBitsQuantizationController(QuantizationController):
    """
    Quantization controller that additionally owns a FracBits compression loss and a
    scheduler which freezes the quantizer bit widths at a configured epoch.
    """

    def __init__(self, target_model: NNCFNetwork,
                 config: NNCFConfig,
                 debug_interface: QuantizationDebugInterface,
                 weight_quantizers: Dict[WeightQuantizerId, WeightQuantizerInfo],
                 non_weight_quantizers: Dict[NonWeightQuantizerId, NonWeightQuantizerInfo],
                 groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers,
                 quantizers_input_shapes: Dict[QuantizerId, Tuple[int]],
                 build_time_metric_info: QuantizationShareBuildTimeInfo = None,
                 build_time_range_init_params: PTRangeInitParams = None):
        """
        Initializes the base controller, then installs the FracBits loss and scheduler.

        All parameters are forwarded unchanged to `QuantizationController.__init__`;
        `target_model` is also passed to the FracBits loss.

        :raises RuntimeError: If a required entry of the 'fracbits_quantization'
            algorithm config is missing (see `_set_fracbits_loss` and `_set_scheduler`).
        """
        super().__init__(target_model, config, debug_interface, weight_quantizers, non_weight_quantizers,
                         groups_of_adjacent_quantizers, quantizers_input_shapes,
                         build_time_metric_info, build_time_range_init_params)
        self._set_fracbits_loss(target_model)
        self._set_scheduler()

    def _set_fracbits_loss(self, target_model: NNCFNetwork):
        """
        Creates the compression loss described by the "loss" section of the algorithm config.

        The section must provide "type", "compression_rate" and "criteria".

        :param target_model: Model passed to the loss constructor.
        :raises RuntimeError: If the "loss" section or one of its required keys is not set.
        """
        loss_config = self._get_algo_config().get("loss")

        if loss_config is None:
            raise RuntimeError("You didn't set loss config.")

        loss_type = loss_config.get("type")
        if loss_type is None:
            raise RuntimeError("You didn't set loss.type config.")

        compression_rate = loss_config.get("compression_rate")
        if compression_rate is None:
            raise RuntimeError("You didn't set compression_rate config.")

        criteria = loss_config.get("criteria")
        if criteria is None:
            raise RuntimeError("You didn't set criteria config.")

        # The loss class is looked up by its registered name and always receives the same keyword set.
        self._loss: PTCompressionLoss = FRACBITS_LOSSES.get(loss_type)(
            model=target_model, compression_rate=compression_rate, criteria=criteria)

    def _set_scheduler(self):
        """
        Creates the scheduler that calls `freeze_bit_widths` at the configured "freeze_epoch".

        :raises RuntimeError: If "freeze_epoch" is not set in the algorithm config.
        """
        freeze_epoch = self._get_algo_config().get("freeze_epoch", None)

        if freeze_epoch is None:
            raise RuntimeError("You didn't set freeze_epoch config.")

        def _callback():
            self.freeze_bit_widths()

        self._scheduler = FracBitsQuantizationScheduler(
            freeze_callback=_callback,
            freeze_epoch=freeze_epoch
        )

    def _get_algo_config(self) -> Dict:
        """
        :return: The part of `self.config` that belongs to the 'fracbits_quantization' algorithm.
        """
        return extract_algo_specific_config(self.config, algo_name_to_match="fracbits_quantization")

    def freeze_bit_widths(self):
        """
        Calls `freeze_num_bits()` on every quantizer in `self.all_quantizations`.
        """
        for q in self.all_quantizations.values():
            q.freeze_num_bits()

    def statistics(self, quickly_collected_only=False) -> NNCFStatistics:
        """
        Collects the base quantization statistics and adds the FracBits loss state to them.

        :param quickly_collected_only: Forwarded to the base class implementation.
        :return: Statistics of the base controller with an extra `FracBitsStatistics`
            entry registered under this controller's name.
        """
        @contextmanager
        def _base_name_context():
            # The base implementation runs with the controller temporarily named "quantization"
            # (presumably the key it registers its statistics under - confirm against
            # QuantizationController.statistics).
            tmp_name = self._name
            self._name = "quantization"
            try:
                yield self.name
            finally:
                # Restore the real name even if the base call raises, so the controller
                # is never left renamed.
                self._name = tmp_name

        with _base_name_context():
            nncf_statistics = super().statistics(quickly_collected_only)

        nncf_statistics.register(self.name, FracBitsStatistics(self._loss.get_state()))

        return nncf_statistics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
""" | ||
Copyright (c) 2022 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from numbers import Number | ||
from typing import Dict, Union | ||
import torch | ||
from nncf.common.utils.registry import Registry | ||
from nncf.torch.compression_method_api import PTCompressionLoss | ||
from nncf.torch.module_operations import UpdateWeight | ||
from nncf.torch.nncf_network import NNCFNetwork | ||
from torch import nn | ||
from dataclasses import dataclass | ||
from nncf.torch.quantization.layers import BaseQuantizer | ||
from nncf.common.utils.logger import logger as nncf_logger | ||
|
||
|
||
FRACBITS_LOSSES = Registry("fracbits_loss") | ||
EPS = 1e-6 | ||
|
||
|
||
@dataclass
class ModuleQuantizerPair:
    """
    Couples a module with the quantizer applied to its weight.
    """
    # Module that owns the quantized weight.
    module: nn.Module
    # Quantizer operation attached to that module's weight.
    quantizer: BaseQuantizer
|
||
|
||
@FRACBITS_LOSSES.register("model_size")
class ModelSizeCompressionLoss(PTCompressionLoss):
    """
    Compression loss based on the model weight size in bits.

    The loss compares the ratio `initial size / current size` with a target
    compression rate using an L1 or L2 criterion.
    """

    def __init__(self, model: NNCFNetwork, compression_rate: float, criteria: str = "L1", **kwargs):
        """
        :param model: Model whose `UpdateWeight` operations are collected.
        :param compression_rate: Target value of `initial size / current size`.
        :param criteria: "L1" for `nn.L1Loss` or "L2" for `nn.MSELoss`.
        :param kwargs: Accepted for signature compatibility and not used.
        :raises RuntimeError: If `criteria` is not one of the supported names.
        """
        super().__init__()
        self._model = model
        self._compression_rate = compression_rate
        self._criteria = self._get_criteria(criteria)

        self._w_q_pairs: Dict[str, ModuleQuantizerPair] = {}

        for name, module in self._model.named_modules():
            if isinstance(module, UpdateWeight):
                # Dropping the last two name components yields the name of the module that owns
                # the operation (presumably "<owner>.<ops container>.<op id>" - confirm).
                parent_name = ".".join(name.split(".")[:-2])
                parent_module = self._model.get_submodule(parent_name)

                self._w_q_pairs[parent_name] = ModuleQuantizerPair(parent_module, module.op)

        # The reference size is a constant, so it is computed without tracking gradients.
        with torch.no_grad():
            self._init_model_size = self._get_model_size()

    def calculate(self) -> torch.Tensor:
        """
        :return: Criterion value between the current and the target compression rate.
        """
        cur_comp_rate = self._init_model_size / (self._get_model_size() + EPS)
        tgt_comp_rate = torch.full_like(cur_comp_rate, self._compression_rate)

        return self._criteria(cur_comp_rate, tgt_comp_rate)

    def _get_criteria(self, criteria) -> nn.modules.loss._Loss:
        """
        Maps a criterion name to a loss module.

        :param criteria: "L1" or "L2".
        :return: `nn.L1Loss()` for "L1", `nn.MSELoss()` for "L2".
        :raises RuntimeError: For any other name.
        """
        if criteria == "L1":
            return nn.L1Loss()
        if criteria == "L2":
            return nn.MSELoss()
        raise RuntimeError(f"Unknown criteria = {criteria}.")

    def _get_model_size(self) -> Union[torch.Tensor, Number]:
        """
        Sums `number of weight elements * frac_num_bits` over all collected module/quantizer pairs.

        :return: The accumulated size; plain `0` when no pair has been collected.
        """
        def _get_module_size(module: nn.Module, num_bits: Union[int, torch.Tensor]) -> Union[torch.Tensor, Number]:
            if isinstance(module, (nn.modules.conv._ConvNd, nn.Linear)):
                return (module.weight.shape.numel() * num_bits).sum()
            # Unsupported module types contribute nothing to the size.
            nncf_logger.warning(f"module={module} is not supported by ModelSizeCompressionLoss. Skip it.")
            return 0.

        return sum(_get_module_size(pair.module, pair.quantizer.frac_num_bits)
                   for pair in self._w_q_pairs.values())

    @torch.no_grad()
    def get_state(self) -> Dict[str, Number]:
        """
        :return: The current compression rate under "compression_rate" and the `frac_num_bits`
            value of every collected quantizer under "frac_bits/<module name>".
        """
        # Convert the whole ratio, not only its denominator, so the state holds a plain number.
        # float() also covers the case where the sizes are Python numbers rather than tensors.
        compression_rate = self._init_model_size / (self._get_model_size() + EPS)
        states = {
            "compression_rate": float(compression_rate)
        }

        for name, pair in self._w_q_pairs.items():
            states[f"frac_bits/{name}"] = pair.quantizer.frac_num_bits.item()

        return states
|
||
|
||
@FRACBITS_LOSSES.register("bitops")
class BitOpsCompressionLoss(PTCompressionLoss):
    """
    Placeholder for a bit-operations based compression loss; not implemented yet.
    """

    def __init__(self, model: NNCFNetwork = None, compression_rate: float = None, criteria: str = "L1", **kwargs):
        """
        Accepts the same arguments as the other registered FracBits losses so that
        creating it by registry name with those keywords does not fail on the signature.
        The arguments are currently not used.

        :param model: Not used.
        :param compression_rate: Not used.
        :param criteria: Not used.
        :param kwargs: Not used.
        """
        super().__init__()

    def calculate(self) -> torch.Tensor:
        """
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("BitOpsCompressionLoss.calculate() is not implemented yet.")

    @torch.no_grad()
    def get_state(self) -> Dict[str, Number]:
        """
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("BitOpsCompressionLoss.get_state() is not implemented yet.")
Oops, something went wrong.