Implement FracBitsQuantizationBuilder and Controller #1234

Merged
3 changes: 2 additions & 1 deletion nncf/common/statistics.py
@@ -104,12 +104,13 @@ def register(self, algorithm_name: str, stats: Statistics):
- quantization
- filter_pruning
- binarization
- fracbits_quantization
:param stats: Statistics of the algorithm.
"""

available_algorithms = [
'magnitude_sparsity', 'rb_sparsity', 'const_sparsity',
'quantization', 'filter_pruning', 'binarization'
'quantization', 'filter_pruning', 'binarization', 'fracbits_quantization'
]
if algorithm_name not in available_algorithms:
raise ValueError('Can not register statistics for the algorithm. '
22 changes: 21 additions & 1 deletion nncf/config/experimental_schema.py
@@ -289,11 +289,31 @@
"additionalProperties": False
}

########################################################################################################################
# FracBits Quantization
########################################################################################################################
FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG = 'fracbits_quantization'
FRACBITS_QUANTIZATION_SCHEMA = copy.deepcopy(QUANTIZATION_SCHEMA)
FRACBITS_QUANTIZATION_SCHEMA['properties']['algorithm']['const'] = FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG
FRACBITS_QUANTIZATION_SCHEMA['properties']['freeze_epoch'] = with_attributes(
NUMBER, description="The epoch at which fractional bit widths are frozen to integers by rounding.")
FRACBITS_QUANTIZATION_SCHEMA['properties']['loss'] = {
"type": "object",
"properties": {
"type": with_attributes(STRING, description="Type of compression loss. Choose model_size or bitops."),
"compression_rate": with_attributes(NUMBER, description="Target compression rate"),
"criteria": with_attributes(STRING, description="Criteria to measure the distance between the target "
"compression rate and the currrent compression rate. Choose L1 or L2."),
},
"additionalProperties": False
}

########################################################################################################################
# All experimental schemas
########################################################################################################################

EXPERIMENTAL_REF_VS_ALGO_SCHEMA = {
EXPERIMENTAL_QUANTIZATION_ALGO_NAME_IN_CONFIG: EXPERIMENTAL_QUANTIZATION_SCHEMA,
BOOTSTRAP_NAS_ALGO_NAME_IN_CONFIG: BOOTSTRAP_NAS_SCHEMA
BOOTSTRAP_NAS_ALGO_NAME_IN_CONFIG: BOOTSTRAP_NAS_SCHEMA,
FRACBITS_QUANTIZATION_ALGO_NAME_IN_CONFIG: FRACBITS_QUANTIZATION_SCHEMA
}
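
For reference, a minimal config that exercises this schema might look like the sketch below. It is illustrative only: the values are made up, and the standard quantization properties inherited from QUANTIZATION_SCHEMA (e.g. "initializer", "weights", "activations") are omitted.

# Illustrative NNCF config for the new algorithm; values are made up,
# field names follow FRACBITS_QUANTIZATION_SCHEMA above.
from nncf import NNCFConfig

nncf_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 3, 224, 224]},
    "compression": {
        "algorithm": "fracbits_quantization",
        "freeze_epoch": 5,            # epoch at which fractional bit widths are rounded and frozen
        "loss": {
            "type": "model_size",     # "bitops" is also accepted by the schema but not implemented yet
            "compression_rate": 1.5,  # target ratio of initial model size to compressed model size
            "criteria": "L1"          # distance measure between target and current compression rate
        }
    }
})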
49 changes: 49 additions & 0 deletions nncf/experimental/torch/fracbits/builder.py
@@ -0,0 +1,49 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from nncf.experimental.torch.fracbits.controller import FracBitsQuantizationController
from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS
from nncf.torch.compression_method_api import PTCompressionAlgorithmController
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.algo import QuantizationBuilder
from nncf.torch.quantization.layers import PTQuantizerSetup
from nncf.common.quantization.structs import QuantizationMode
from nncf.experimental.torch.fracbits.quantizer import FracBitsQuantizationMode


@PT_COMPRESSION_ALGORITHMS.register('fracbits_quantization')
class FracBitsQuantizationBuilder(QuantizationBuilder):
def _get_quantizer_setup(self, target_model: NNCFNetwork) -> PTQuantizerSetup:
setup = super()._get_quantizer_setup(target_model)

for q_point in setup.quantization_points.values():
mode = q_point.qspec.mode
if mode == QuantizationMode.ASYMMETRIC:
q_point.qspec.mode = FracBitsQuantizationMode.ASYMMETRIC
elif mode == QuantizationMode.SYMMETRIC:
q_point.qspec.mode = FracBitsQuantizationMode.SYMMETRIC
else:
raise ValueError(f"qsepc.mode={mode} is unknown.")

return setup

def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
return FracBitsQuantizationController(model,
self.config,
self._debug_interface,
self._weight_quantizers,
self._non_weight_quantizers,
self._groups_of_adjacent_quantizers,
self._quantizers_input_shapes,
build_time_metric_info=self._build_time_metric_infos,
build_time_range_init_params=self._range_init_params)
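
Because the builder is registered under 'fracbits_quantization', it is picked up by the standard NNCF entry point. A rough usage sketch, assuming the nncf_config from the schema example above and a plain torchvision model (both illustrative, not part of this PR):

# The generic create_compressed_model entry point resolves "fracbits_quantization"
# to FracBitsQuantizationBuilder through the PT_COMPRESSION_ALGORITHMS registry.
import torchvision.models as models
from nncf.torch import create_compressed_model

model = models.resnet18()
compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)
# compression_ctrl is a FracBitsQuantizationController; its loss and scheduler
# are configured in controller.py below.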
109 changes: 109 additions & 0 deletions nncf/experimental/torch/fracbits/controller.py
@@ -0,0 +1,109 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from contextlib import contextmanager
from typing import Dict, Tuple

from nncf.common.quantization.structs import NonWeightQuantizerId, QuantizerId, WeightQuantizerId
from nncf.common.statistics import NNCFStatistics
from nncf.config.config import NNCFConfig
from nncf.config.extractors import extract_algo_specific_config
from nncf.experimental.torch.fracbits.statistics import FracBitsStatistics
from nncf.experimental.torch.fracbits.scheduler import FracBitsQuantizationScheduler
from nncf.torch.compression_method_api import PTCompressionLoss
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.algo import QuantizationController, QuantizationDebugInterface
from nncf.torch.quantization.init_range import PTRangeInitParams
from nncf.torch.quantization.metrics import QuantizationShareBuildTimeInfo
from nncf.torch.quantization.precision_init.adjacent_quantizers import GroupsOfAdjacentQuantizers
from nncf.torch.quantization.structs import NonWeightQuantizerInfo, WeightQuantizerInfo
from nncf.experimental.torch.fracbits.loss import FRACBITS_LOSSES


class FracBitsQuantizationController(QuantizationController):
def __init__(self, target_model: NNCFNetwork,
config: NNCFConfig,
debug_interface: QuantizationDebugInterface,
weight_quantizers: Dict[WeightQuantizerId, WeightQuantizerInfo],
non_weight_quantizers: Dict[NonWeightQuantizerId, NonWeightQuantizerInfo],
groups_of_adjacent_quantizers: GroupsOfAdjacentQuantizers,
quantizers_input_shapes: Dict[QuantizerId, Tuple[int]],
build_time_metric_info: QuantizationShareBuildTimeInfo = None,
build_time_range_init_params: PTRangeInitParams = None):
super().__init__(target_model, config, debug_interface, weight_quantizers, non_weight_quantizers,
groups_of_adjacent_quantizers, quantizers_input_shapes,
build_time_metric_info, build_time_range_init_params)
self._set_fracbits_loss(target_model)
self._set_scheduler()

def _set_fracbits_loss(self, target_model: NNCFNetwork):
algo_config = self._get_algo_config()

if algo_config.get("loss") is None:
raise RuntimeError("You didn't set loss config.")

loss_config = algo_config.get("loss")

if loss_config.get("type") is None:
raise RuntimeError("You didn't set loss.type config.")

if loss_config.get("compression_rate") is None:
raise RuntimeError("You didn't set compression_rate config.")

if loss_config.get("criteria") is None:
raise RuntimeError("You didn't set criteria config.")

loss_type = loss_config.get("type")
compression_rate = loss_config.get("compression_rate")
criteria = loss_config.get("criteria")

self._loss: PTCompressionLoss = FRACBITS_LOSSES.get(loss_type)(
model=target_model, compression_rate=compression_rate, criteria=criteria)

def _set_scheduler(self):
algo_config = self._get_algo_config()

freeze_epoch = algo_config.get("freeze_epoch", None)

if freeze_epoch is None:
raise RuntimeError("You didn't set freeze_epoch config.")

def _callback():
self.freeze_bit_widths()

self._scheduler = FracBitsQuantizationScheduler(
freeze_callback=_callback,
freeze_epoch=freeze_epoch
)

def _get_algo_config(self) -> Dict:
return extract_algo_specific_config(self.config, algo_name_to_match="fracbits_quantization")

def freeze_bit_widths(self):
for q in self.all_quantizations.values():
q.freeze_num_bits()

def statistics(self, quickly_collected_only=False) -> NNCFStatistics:
@contextmanager
def _base_name_context():
tmp_name = self._name
self._name = "quantization"
yield self.name
self._name = tmp_name

with _base_name_context():
nncf_statistics = super().statistics(quickly_collected_only)

nncf_statistics.register(self.name, FracBitsStatistics(self._loss.get_state()))

return nncf_statistics
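
The controller wires the FracBits loss and scheduler into the usual NNCF fine-tuning loop. A hedged sketch of how they would be consumed, assuming an optimizer, criterion, and train_loader are already set up (this is the generic NNCF pattern, not code from this PR):

# Generic fine-tuning pattern: the compression loss is added to the task loss,
# and scheduler.epoch_step() lets the scheduler freeze bit widths at freeze_epoch.
for epoch in range(num_epochs):
    compression_ctrl.scheduler.epoch_step()
    for images, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(compressed_model(images), targets) + compression_ctrl.loss()
        loss.backward()
        optimizer.step()
    print(compression_ctrl.statistics().to_str())  # includes FracBitsStatistics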
101 changes: 101 additions & 0 deletions nncf/experimental/torch/fracbits/loss.py
@@ -0,0 +1,101 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from dataclasses import dataclass
from numbers import Number
from typing import Dict, Union

import torch
from torch import nn

from nncf.common.utils.logger import logger as nncf_logger
from nncf.common.utils.registry import Registry
from nncf.torch.compression_method_api import PTCompressionLoss
from nncf.torch.module_operations import UpdateWeight
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import BaseQuantizer


FRACBITS_LOSSES = Registry("fracbits_loss")
EPS = 1e-6


@dataclass
class ModuleQuantizerPair:
module: nn.Module
quantizer: BaseQuantizer


@FRACBITS_LOSSES.register("model_size")
class ModelSizeCompressionLoss(PTCompressionLoss):
def __init__(self, model: NNCFNetwork, compression_rate: float, criteria: str = "L1", **kwargs):
super().__init__()
self._model = model
self._compression_rate = compression_rate
self._criteria = self._get_criteria(criteria)

self._w_q_pairs: Dict[str, ModuleQuantizerPair] = {}

for name, module in self._model.named_modules():
if isinstance(module, UpdateWeight):
parent_name = ".".join(name.split(".")[:-2])
parent_module = self._model.get_submodule(parent_name)

self._w_q_pairs[parent_name] = ModuleQuantizerPair(parent_module, module.op)

with torch.no_grad():
self._init_model_size = self._get_model_size()

def calculate(self) -> torch.Tensor:
cur_comp_rate = self._init_model_size / (self._get_model_size() + EPS)
tgt_comp_rate = torch.full_like(cur_comp_rate, self._compression_rate)

return self._criteria(cur_comp_rate, tgt_comp_rate)

def _get_criteria(self, criteria) -> nn.modules.loss._Loss:
if criteria == "L1":
return nn.L1Loss()
if criteria == "L2":
return nn.MSELoss()
raise RuntimeError(f"Unknown criteria = {criteria}.")

def _get_model_size(self) -> Union[torch.Tensor, Number]:
def _get_module_size(module: nn.Module, num_bits: Union[int, torch.Tensor]) -> Union[torch.Tensor, Number]:
if isinstance(module, (nn.modules.conv._ConvNd, nn.Linear)):
return (module.weight.shape.numel() * num_bits).sum()
nncf_logger.warning("module={module} is not supported by ModelSizeCompressionLoss. Skip it.")
return 0.

return sum([_get_module_size(pair.module, pair.quantizer.frac_num_bits) for pair in self._w_q_pairs.values()])

@torch.no_grad()
def get_state(self) -> Dict[str, Number]:
states = {
"compression_rate": self._init_model_size / (self._get_model_size() + EPS).item()
}

for name, pair in self._w_q_pairs.items():
states[f"frac_bits/{name}"] = pair.quantizer.frac_num_bits.item()

return states


@FRACBITS_LOSSES.register("bitops")
class BitOpsCompressionLoss(PTCompressionLoss):
def __init__(self):
super().__init__()

def calculate(self) -> torch.Tensor:
raise NotImplementedError()

@torch.no_grad()
def get_state(self) -> Dict[str, Number]:
raise NotImplementedError()
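
As a concrete illustration of the model_size objective (numbers are made up): the size term is sum(weight.numel() * frac_num_bits) over the tracked layers, and the loss is the chosen criteria applied to the gap between init_size / current_size and the target compression_rate. A self-contained sketch with two hypothetical layers, using a plain tensor in place of the learnable frac_num_bits parameters:

import torch
from torch import nn

init_size = 1000 * 8 + 500 * 8                     # two layers, both starting at 8 bits
frac_bits = torch.tensor([5.3, 6.1], requires_grad=True)
cur_size = 1000 * frac_bits[0] + 500 * frac_bits[1]
cur_rate = init_size / (cur_size + 1e-6)           # roughly 1.44
loss = nn.L1Loss()(cur_rate, torch.full_like(cur_rate, 1.5))
loss.backward()                                     # gradients flow back into the bit widths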
8 changes: 6 additions & 2 deletions nncf/experimental/torch/fracbits/quantizer.py
@@ -17,13 +17,17 @@
from typing import Dict
import torch

from nncf.experimental.torch.fracbits.structs import FracBitsQuantizationMode

from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter
from nncf.torch.quantization.layers import (
QUANTIZATION_MODULES, AsymmetricQuantizer, PTQuantizerSpec, SymmetricQuantizer)
from nncf.torch.quantization.quantize_functions import asymmetric_quantize, symmetric_quantize
from nncf.torch.utils import no_jit_trace
from nncf.common.quantization.structs import QuantizationMode


class FracBitsQuantizationMode(QuantizationMode):
SYMMETRIC = 'fracbits_symmetric'
ASYMMETRIC = 'fracbits_asymmetric'


@COMPRESSION_MODULES.register()
28 changes: 28 additions & 0 deletions nncf/experimental/torch/fracbits/scheduler.py
@@ -0,0 +1,28 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Callable
from nncf.common.schedulers import BaseCompressionScheduler
from nncf.common.utils.logger import logger as nncf_logger

class FracBitsQuantizationScheduler(BaseCompressionScheduler):
def __init__(self, freeze_epoch: int, freeze_callback: Callable):
super().__init__()
self._freeze_epoch = freeze_epoch
self._freeze_callback = freeze_callback

def epoch_step(self, next_epoch=None):
super().epoch_step(next_epoch)
if self._current_epoch == self._freeze_epoch:
nncf_logger.info(f"Current epoch is {self._current_epoch}. Freeze fractional bit widths.")
self._freeze_callback()
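
A self-contained check of the freeze behaviour, assuming BaseCompressionScheduler starts at epoch -1 and epoch_step() increments the current epoch (the import path is the one added by this PR):

from nncf.experimental.torch.fracbits.scheduler import FracBitsQuantizationScheduler

frozen = []
scheduler = FracBitsQuantizationScheduler(freeze_epoch=2, freeze_callback=lambda: frozen.append(True))
for _ in range(4):
    scheduler.epoch_step()  # current epoch becomes 0, 1, 2, 3
print(frozen)               # expected: [True]; the callback fires once, at epoch 2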
35 changes: 35 additions & 0 deletions nncf/experimental/torch/fracbits/statistics.py
@@ -0,0 +1,35 @@
"""
Copyright (c) 2022 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from numbers import Number
from typing import Dict
from nncf.api.statistics import Statistics

from nncf.common.utils.tensorboard import convert_to_dict


class FracBitsStatistics(Statistics):
def __init__(self, states: Dict[str, Number]) -> None:
super().__init__()
self.data = states

def to_str(self) -> str:
return str(self.data)


@convert_to_dict.register(FracBitsStatistics)
def _convert_to_dict(stats: FracBitsStatistics, algorithm_name: str):
tensorboard_stats = {
algorithm_name + "/" + k: v for k, v in stats.data.items()
}
return tensorboard_stats
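
With the converter registered above, the controller's state dict ends up in TensorBoard under per-algorithm keys. A small sketch with made-up values (the quantizer name "conv1" is hypothetical):

from nncf.common.utils.tensorboard import convert_to_dict
from nncf.experimental.torch.fracbits.statistics import FracBitsStatistics

stats = FracBitsStatistics({"compression_rate": 1.42, "frac_bits/conv1": 5.7})
print(convert_to_dict(stats, "fracbits_quantization"))
# {'fracbits_quantization/compression_rate': 1.42, 'fracbits_quantization/frac_bits/conv1': 5.7}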