Skip to content

Commit

Permalink
Implement FracBitsQuantizationBuilder and Controller (openvinotoolkit…
Browse files Browse the repository at this point in the history
…#1234)

* Implement FracBitsQuantizationBuilder and Controller

 - Implement Builder and Controller
 - Add and test ModelSizeCompressionLoss

Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim committed Oct 4, 2022
1 parent e9f2086 commit 7bd7b87
Show file tree
Hide file tree
Showing 12 changed files with 859 additions and 5 deletions.
3 changes: 2 additions & 1 deletion nncf/common/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ def register(self, algorithm_name: str, stats: Statistics):
- quantization
- filter_pruning
- binarization
- fracbits_quantization
:param stats: Statistics of the algorithm.
"""

available_algorithms = [
'magnitude_sparsity', 'rb_sparsity', 'const_sparsity',
'quantization', 'filter_pruning', 'binarization'
'quantization', 'filter_pruning', 'binarization', "fracbits_quantization"
]
if algorithm_name not in available_algorithms:
raise ValueError('Can not register statistics for the algorithm. '
Expand Down
22 changes: 21 additions & 1 deletion nncf/config/schemata/experimental_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,11 +290,31 @@
"additionalProperties": False
}

########################################################################################################################
# FracBits Quantization
########################################################################################################################
FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG = 'fracbits_quantization'
FRACBITS_QUANTIZATION_SCHEMA = copy.deepcopy(QUANTIZATION_SCHEMA)
FRACBITS_QUANTIZATION_SCHEMA['properties']['algorithm']['const'] = FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG
FRACBITS_QUANTIZATION_SCHEMA['properties']['freeze_epoch'] = with_attributes(
NUMBER, description="The number of epoch to freeze fractional bit widths to integers by rounding them.")
FRACBITS_QUANTIZATION_SCHEMA['properties']['loss'] = {
"type": "object",
"properties": {
"type": with_attributes(STRING, description="Type of compression loss. Choose model_size or bitops."),
"compression_rate": with_attributes(NUMBER, description="Target compression rate"),
"criteria": with_attributes(STRING, description="Criteria to measure the distance between the target "
"compression rate and the currrent compression rate. Choose L1 or L2."),
},
"additionalProperties": False
}

########################################################################################################################
# All experimental schemas
########################################################################################################################

EXPERIMENTAL_REF_VS_ALGO_SCHEMA = {
EXPERIMENTAL_QUANTIZATION_ALGO_NAME_IN_CONFIG: EXPERIMENTAL_QUANTIZATION_SCHEMA,
BOOTSTRAP_NAS_ALGO_NAME_IN_CONFIG: BOOTSTRAP_NAS_SCHEMA
BOOTSTRAP_NAS_ALGO_NAME_IN_CONFIG: BOOTSTRAP_NAS_SCHEMA,
FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG: FRACBITS_QUANTIZATION_SCHEMA
}
49 changes: 49 additions & 0 deletions nncf/experimental/torch/fracbits/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from nncf.experimental.torch.fracbits.controller import FracBitsQuantizationController
from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.algo import QuantizationBuilder
from nncf.torch.quantization.layers import PTQuantizerSetup
from nncf.common.quantization.structs import QuantizationMode
from nncf.experimental.torch.fracbits.quantizer import FracBitsQuantizationMode


@PT_COMPRESSION_ALGORITHMS.register('fracbits_quantization')
class FracBitsQuantizationBuilder(QuantizationBuilder):
def _get_quantizer_setup(self, target_model: NNCFNetwork) -> PTQuantizerSetup:
setup = super()._get_quantizer_setup(target_model)

for q_point in setup.quantization_points.values():
mode = q_point.qspec.mode
if mode == QuantizationMode.ASYMMETRIC:
q_point.qspec.mode = FracBitsQuantizationMode.ASYMMETRIC
elif mode == QuantizationMode.SYMMETRIC:
q_point.qspec.mode = FracBitsQuantizationMode.SYMMETRIC
else:
raise ValueError(f"qsepc.mode={mode} is unknown.")

return setup

def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
return FracBitsQuantizationController(model,
self.config,
self._debug_interface,
self._weight_quantizers,
self._non_weight_quantizers,
self._groups_of_adjacent_quantizers,
self._quantizers_input_shapes,
build_time_metric_info=self._build_time_metric_infos,
build_time_range_init_params=self._range_init_params)
109 changes: 109 additions & 0 deletions nncf/experimental/torch/fracbits/controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from contextlib import contextmanager
from typing import Dict, Tuple

from nncf.common.quantization.structs import NonWeightQuantizerId, QuantizerId, WeightQuantizerId
from nncf.common.statistics import NNCFStatistics
from nncf.config.config import NNCFConfig
from nncf.config.extractors import extract_algo_specific_config
from nncf.experimental.torch.fracbits.statistics import FracBitsStatistics
from nncf.experimental.torch.fracbits.scheduler import FracBitsQuantizationScheduler
from nncf.torch.compression_method_api import PTCompressionLoss
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.algo import QuantizationController, QuantizationDebugInterface
from nncf.torch.quantization.init_range import PTRangeInitParams
from nncf.torch.quantization.metrics import QuantizationShareBuildTimeInfo
from nncf.torch.quantization.precision_init.adjacent_quantizers import GroupsOfAdjacentQuantizers
from nncf.torch.quantization.structs import NonWeightQuantizerInfo, WeightQuantizerInfo
from nncf.experimental.torch.fracbits.loss import FRACBITS_LOSSES


class FracBitsQuantizationController(QuantizationController):
def __init__(self, target_model: NNCFNetwork,
config: NNCFConfig,
debug_interface: QuantizationDebugInterface,
weight_quantizers: Dict[WeightQuantizerId, WeightQuantizerInfo],
non_weight_quantizers: Dict[NonWeightQuantizerId, NonWeightQuantizerInfo],
groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers,
quantizers_input_shapes: Dict[QuantizerId, Tuple[int]],
build_time_metric_info: QuantizationShareBuildTimeInfo = None,
build_time_range_init_params: PTRangeInitParams = None):
super().__init__(target_model, config, debug_interface, weight_quantizers, non_weight_quantizers,
groups_of_adjacent_quantizers, quantizers_input_shapes,
build_time_metric_info, build_time_range_init_params)
self._set_fracbits_loss(target_model)
self._set_scheduler()

def _set_fracbits_loss(self, target_model: NNCFNetwork):
algo_config = self._get_algo_config()

if algo_config.get("loss") is None:
raise RuntimeError("You didn't set loss config.")

loss_config = algo_config.get("loss")

if loss_config.get("type") is None:
raise RuntimeError("You didn't set loss.type config.")

if loss_config.get("compression_rate") is None:
raise RuntimeError("You didn't set compression_rate config.")

if loss_config.get("criteria") is None:
raise RuntimeError("You didn't set criteria config.")

loss_type = loss_config.get("type")
compression_rate = loss_config.get("compression_rate")
criteria = loss_config.get("criteria")

self._loss: PTCompressionLoss = FRACBITS_LOSSES.get(loss_type)(
model=target_model, compression_rate=compression_rate, criteria=criteria)

def _set_scheduler(self):
algo_config = self._get_algo_config()

freeze_epoch = algo_config.get("freeze_epoch", None)

if freeze_epoch is None:
raise RuntimeError("You didn't set freeze_epoch config.")

def _callback():
self.freeze_bit_widths()

self._scheduler = FracBitsQuantizationScheduler(
freeze_callback=_callback,
freeze_epoch=freeze_epoch
)

def _get_algo_config(self) -> Dict:
return extract_algo_specific_config(self.config, algo_name_to_match="fracbits_quantization")

def freeze_bit_widths(self):
for q in self.all_quantizations.values():
q.freeze_num_bits()

def statistics(self, quickly_collected_only=False) -> NNCFStatistics:
@contextmanager
def _base_name_context():
tmp_name = self._name
self._name = "quantization"
yield self.name
self._name = tmp_name

with _base_name_context():
nncf_statistics = super().statistics(quickly_collected_only)

nncf_statistics.register(self.name, FracBitsStatistics(self._loss.get_state()))

return nncf_statistics
101 changes: 101 additions & 0 deletions nncf/experimental/torch/fracbits/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from numbers import Number
from typing import Dict, Union
import torch
from nncf.common.utils.registry import Registry
from nncf.torch.compression_method_api import PTCompressionLoss
from nncf.torch.module_operations import UpdateWeight
from nncf.torch.nncf_network import NNCFNetwork
from torch import nn
from dataclasses import dataclass
from nncf.torch.quantization.layers import BaseQuantizer
from nncf.common.utils.logger import logger as nncf_logger


FRACBITS_LOSSES = Registry("fracbits_loss")
EPS = 1e-6


@dataclass
class ModuleQuantizerPair:
module: nn.Module
quantizer: BaseQuantizer


@FRACBITS_LOSSES.register("model_size")
class ModelSizeCompressionLoss(PTCompressionLoss):
def __init__(self, model: NNCFNetwork, compression_rate: float, criteria: str = "L1", **kwargs):
super().__init__()
self._model = model
self._compression_rate = compression_rate
self._criteria = self._get_criteria(criteria)

self._w_q_pairs: Dict[str, ModuleQuantizerPair] = {}

for name, module in self._model.named_modules():
if isinstance(module, UpdateWeight):
parent_name = ".".join(name.split(".")[:-2])
parent_module = self._model.get_submodule(parent_name)

self._w_q_pairs[parent_name] = ModuleQuantizerPair(parent_module, module.op)

with torch.no_grad():
self._init_model_size = self._get_model_size()

def calculate(self) -> torch.Tensor:
cur_comp_rate = self._init_model_size / (self._get_model_size() + EPS)
tgt_comp_rate = torch.full_like(cur_comp_rate, self._compression_rate)

return self._criteria(cur_comp_rate, tgt_comp_rate)

def _get_criteria(self, criteria) -> nn.modules.loss._Loss:
if criteria == "L1":
return nn.L1Loss()
if criteria == "L2":
return nn.MSELoss()
raise RuntimeError(f"Unknown criteria = {criteria}.")

def _get_model_size(self) -> Union[torch.Tensor, Number]:
def _get_module_size(module: nn.Module, num_bits: Union[int, torch.Tensor]) -> Union[torch.Tensor, Number]:
if isinstance(module, (nn.modules.conv._ConvNd, nn.Linear)):
return (module.weight.shape.numel() * num_bits).sum()
nncf_logger.warning("module={module} is not supported by ModelSizeCompressionLoss. Skip it.")
return 0.

return sum([_get_module_size(pair.module, pair.quantizer.frac_num_bits) for pair in self._w_q_pairs.values()])

@torch.no_grad()
def get_state(self) -> Dict[str, Number]:
states = {
"compression_rate": self._init_model_size / (self._get_model_size() + EPS).item()
}

for name, pair in self._w_q_pairs.items():
states[f"frac_bits/{name}"] = pair.quantizer.frac_num_bits.item()

return states


@FRACBITS_LOSSES.register("bitops")
class BitOpsCompressionLoss(PTCompressionLoss):
def __init__(self):
super().__init__()

def calculate(self) -> torch.Tensor:
raise NotImplementedError()

@torch.no_grad()
def get_state(self) -> Dict[str, Number]:
raise NotImplementedError()
Loading

0 comments on commit 7bd7b87

Please sign in to comment.