Skip to content

Commit

Permalink
Passing the right activation on quantization.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Sep 27, 2021
1 parent 72cecb1 commit e269817
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
1 change: 1 addition & 0 deletions torchvision/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4):
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
self.relu = self.activation
delattr(self, 'activation')
warnings.warn(
"This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)

Expand Down
8 changes: 5 additions & 3 deletions torchvision/models/quantization/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

class QuantizableSqueezeExcitation(SElayer):
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["scale_activation"]=nn.Hardswish
super().__init__(*args, **kwargs)
self.skip_mul = nn.quantized.FloatFunctional()

Expand Down Expand Up @@ -80,11 +81,12 @@ def _load_weights(
model: QuantizableMobileNetV3,
model_url: Optional[str],
progress: bool,
strict: bool
) -> None:
if model_url is None:
raise ValueError("No checkpoint is available for {}".format(arch))
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=strict)


def _mobilenet_v3_model(
Expand All @@ -108,13 +110,13 @@ def _mobilenet_v3_model(
torch.quantization.prepare_qat(model, inplace=True)

if pretrained:
_load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
_load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, False)

torch.quantization.convert(model, inplace=True)
model.eval()
else:
if pretrained:
_load_weights(arch, model, model_urls.get(arch, None), progress)
_load_weights(arch, model, model_urls.get(arch, None), progress, True)

return model

Expand Down

0 comments on commit e269817

Please sign in to comment.