Skip to content

Commit

Permalink
Merge pull request pytorch#749 from raghuramank100/jlin27-quant-tutor…
Browse files Browse the repository at this point in the history
…ials

Fix formatting and clean up tutorial on quantized transfer learning
  • Loading branch information
Jessica Lin authored Dec 6, 2019
2 parents f58558a + a800c77 commit d0ef61b
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions intermediate_source/quantized_transfer_learning_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,11 @@


######################################################################
# Load Data (section not needed as it is covered in the original tutorial)
# Load Data
# ------------------------------------------------------------------------
#
# ..Note :: This section is identical to the original transfer learning tutorial.
#
# We will use ``torchvision`` and ``torch.utils.data`` packages to load
# the data.
#
Expand Down Expand Up @@ -360,7 +362,7 @@ def visualize_model(model, rows=3, cols=3):
# **Notice that when isolating the feature extractor from a quantized
# model, you have to place the quantizer in the beginning and in the end
# of it.**
#
# We write a helper function to create a model with a custom head.

from torch import nn

Expand Down Expand Up @@ -394,8 +396,6 @@ def create_combined_model(model_fe):
)
return new_model

new_model = create_combined_model(model_fe)


######################################################################
# .. warning:: Currently the quantized models can only be run on CPU.
Expand All @@ -404,6 +404,7 @@ def create_combined_model(model_fe):
#

import torch.optim as optim
new_model = create_combined_model(model_fe)
new_model = new_model.to('cpu')

criterion = nn.CrossEntropyLoss()
Expand Down Expand Up @@ -431,7 +432,7 @@ def create_combined_model(model_fe):


######################################################################
# **Part 2. Finetuning the quantizable model**
# Part 2. Finetuning the quantizable model
#
# In this part, we fine tune the feature extractor used for transfer
# learning, and quantize the feature extractor. Note that in both part 1
Expand All @@ -446,18 +447,21 @@ def create_combined_model(model_fe):
# datasets.
#
# The pretrained feature extractor must be quantizable, i.e we need to do
# the following: 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
# using torch.quantization.fuse_modules. 2. Connect the feature extractor
# with a custom head. This requires dequantizing the output of the feature
# extractor. 3. Insert fake-quantization modules at appropriate locations
# in the feature extractor to mimic quantization during training.
# the following:
# 1. Fuse (Conv, BN, ReLU), (Conv, BN) and (Conv, ReLU)
# using torch.quantization.fuse_modules.
# 2. Connect the feature extractor
# with a custom head. This requires dequantizing the output of the feature
# extractor.
# 3. Insert fake-quantization modules at appropriate locations
# in the feature extractor to mimic quantization during training.
#
# For step (1), we use models from torchvision/models/quantization, which
# support a member method fuse_model, which fuses all the conv, bn, and
# relu modules. In general, this would require calling the
# torch.quantization.fuse_modules API with the list of modules to fuse.
#
# Step (2) is done by the function create_custom_model function that we
# Step (2) is done by the function create_combined_model function that we
# used in the previous section.
#
# Step (3) is achieved by using torch.quantization.prepare_qat, which
Expand Down Expand Up @@ -534,4 +538,3 @@ def create_combined_model(model_fe):
plt.ioff()
plt.tight_layout()
plt.show()

0 comments on commit d0ef61b

Please sign in to comment.