
Multi-pretrained weight support - Quantized ResNet50 #4627

Merged
4 commits merged into pytorch:main on Oct 18, 2021

Conversation

datumbox
Contributor

@datumbox datumbox commented Oct 15, 2021

Resolves #4671

Example usage:

from PIL import Image
from torchvision import prototype as P


img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

combos = [
    (P.models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1, True),
    (P.models.ResNet50Weights.ImageNet1K_RefV1, False)
]

for weights, quantize in combos:
    # Initialize model
    print(weights, quantize)
    model = P.models.quantization.resnet50(weights=weights, quantize=quantize)
    model.eval()

    # Initialize inference transforms
    preprocess = weights.transforms()

    # Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)
    prediction = model(batch).squeeze(0).softmax(0)

    # Make predictions
    label = prediction.argmax().item()
    score = prediction[label].item()

    # Use meta to get label
    category_name = weights.meta['categories'][label]
    print(f"{category_name}: {100 * score}%")

cc @pmeier @bjuncek

@datumbox datumbox added the enhancement, module: models.quantization (issues related to the quantizable/quantized models), and prototype labels on Oct 15, 2021
Contributor Author

@datumbox datumbox left a comment

Providing some clarifications on comments:

@@ -110,7 +110,7 @@ def fuse_model(self) -> None:

 def _resnet(
     arch: str,
-    block: Type[Union[BasicBlock, Bottleneck]],
+    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
Contributor Author

Fixing in place a previous "bug" in our typing. The correct types here are the more specific Quantizable versions.

if weights is not None:
    kwargs["num_classes"] = len(weights.meta["categories"])
    if "backend" in weights.meta:
        kwargs["backend"] = weights.meta["backend"]
Contributor Author

@datumbox datumbox Oct 15, 2021

Not all weights are expected to have the "backend" meta field; for example, when quantize=False is passed.

Let's ignore the silent overwriting of parameters (see #4613 (comment)). This can be discussed separately and applied everywhere in a follow-up.
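
To illustrate the guard, here is a minimal sketch; the meta contents below are made up for illustration and are not the actual checkpoint metadata:

# Illustrative meta dicts: the quantized weights carry a "backend" entry, the
# unquantized ones do not, hence the "backend" lookup has to be guarded.
quantized_meta = {"categories": ["tench", "goldfish"], "backend": "fbgemm"}
float_meta = {"categories": ["tench", "goldfish"]}

for meta in (quantized_meta, float_meta):
    kwargs = {"num_classes": len(meta["categories"])}
    if "backend" in meta:
        kwargs["backend"] = meta["backend"]
    print(kwargs)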

kwargs["num_classes"] = len(weights.meta["categories"])
if "backend" in weights.meta:
kwargs["backend"] = weights.meta["backend"]
backend = kwargs.pop("backend", "fbgemm")
Contributor Author

The backend is allowed to be a "hidden" kwargs argument. We can decide to make it public in future iterations.
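
For illustration, a hypothetical call that passes the backend through **kwargs; the "qnnpack" value is just an example of an alternative backend and is not exercised in this PR:

from torchvision import prototype as P

# Hypothetical: "backend" is not a named parameter of the builder, so it travels
# through **kwargs; with no weights supplied, there is no backend meta to override it.
model = P.models.quantization.resnet50(weights=None, quantize=True, backend="qnnpack")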



def resnet50(
    weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
Contributor Author

@datumbox datumbox Oct 15, 2021

We allow passing both quantized and normal weights. This is aligned with the previous behaviour, where different URLs were loaded depending on the value of quantize.
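
Concretely, mirroring the example in the PR description, both of the following calls go through the same builder:

from torchvision import prototype as P

# Quantized weights together with quantize=True ...
qmodel = P.models.quantization.resnet50(
    weights=P.models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1,
    quantize=True,
)
# ... or the plain float weights together with quantize=False.
fmodel = P.models.quantization.resnet50(
    weights=P.models.ResNet50Weights.ImageNet1K_RefV1,
    quantize=False,
)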

if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
if kwargs.pop("pretrained"):
weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
Contributor Author

Which default weights we load depends on whether we quantize or not.

    else:
        weights = None

if quantize:
Contributor Author

The validation of the weights also depends on the value of quantize. Passing the wrong combination will throw an error.
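
A minimal sketch of the branching being described, assuming the verify helper that is mentioned later in this thread (not a verbatim copy of the diff):

# Sketch: the enum used for verification is picked by the quantize flag, so e.g.
# passing ResNet50Weights together with quantize=True fails the verification step.
if quantize:
    weights = QuantizedResNet50Weights.verify(weights)
else:
    weights = ResNet50Weights.verify(weights)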

@datumbox datumbox requested a review from fmassa October 15, 2021 16:29
Member

@NicolasHug NicolasHug left a comment

stamping

Contributor

@prabhat00155 prabhat00155 left a comment

Thanks @datumbox!

Member

@fmassa fmassa left a comment

Thanks for the PR!

I'm approving to unblock. The current structure follows somewhat closely what we do in the main branch, so it's a good start. I've added some comments with some potential ideas to consider in the future.

Also, it would be good to get someone from the quantization team to have a look and see if the weight structure (with backends etc) is still relevant as of today, or if there will be new things coming that we should consider (sparsity, etc).

    **kwargs: Any,
) -> QuantizableResNet:
    if weights is not None:
        kwargs["num_classes"] = len(weights.meta["categories"])
Member

Flagging again #4613 (comment) but we can discuss in a follow-up

}


class QuantizedResNet50Weights(Weights):
Member

Throwing an idea in the wild: as of today (and I think this will be the case for all models), all quantized weights originate from an unquantized weight. Do we want to keep this link somehow in the Weights structure? Do we want the quantized weights to be magically picked from the ResNet50Weights if we pass the quantize flag?

I think it might be important to keep this relationship somehow.
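
One hypothetical way to keep that link, assuming the meta dict is where it would live (the "unquantized" key is invented here for illustration and is not part of the PR):

from torchvision import prototype as P

# Hypothetical: the quantized entry's meta holds a back-reference to its source
# float weights, so the relationship stays discoverable programmatically.
quantized_meta = {
    "backend": "fbgemm",
    "unquantized": P.models.ResNet50Weights.ImageNet1K_RefV1,  # the proposed link
}
source_weights = quantized_meta["unquantized"]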

Contributor Author

We could keep this in the meta-data or, if we are willing to make a commitment, pass it as a proper field of the Weights data class. The reason we don't is that quantization is not very mature at the moment: only classification models have been quantized, and we are examining alternative APIs/approaches (such as FX) to achieve it. For these reasons, I would be in favour of not introducing a direct link and revisiting this decision in the near future.

    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableResNet:
    if "pretrained" in kwargs:
Member

nit: you could probably simplify the code a tiny bit by doing something like

if quantize:
    weights_table = QuantizedResNet50Weights
else:
    weights_table = ResNet50Weights
...
weights = weights_table.ImageNet1K_RefV1  # different naming convention than now
weights = weights_table.verify(weights)

In some sense, it's a bit annoying to have to carry those two Weights enums inside every model builder.

Contributor Author

I decided not to connect the two using the same name because the reference/recipe version might not be the same; i.e. RefV1 for the quantized weights is something different than RefV1 for the non-quantized ones. For example, say you use a different recipe/config to achieve even better quantization of the same weights; now you need a different version of the enum. Moreover, there might be multiple quantized weight enums for the same unquantized weights (for example, if multiple backends should be supported). Some of these points will become clearer in the near future and we can revisit them.
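
For example, the independent versioning leaves room for hypothetical entries like the following (only the first one exists in this PR; the others are made up to illustrate the point):

# Hypothetical member names showing why the quantized enum is versioned on its own:
# the same float weights could back several quantized variants.
possible_quantized_variants = [
    "ImageNet1K_FBGEMM_RefV1",   # the entry added in this PR
    "ImageNet1K_QNNPACK_RefV1",  # hypothetical: a different quantization backend
    "ImageNet1K_FBGEMM_RefV2",   # hypothetical: an improved recipe for the same float weights
]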

@datumbox datumbox merged commit c88423b into pytorch:main Oct 18, 2021
@datumbox datumbox deleted the multiweight/quanitzed_resnet50 branch October 18, 2021 12:47
mszhanyi pushed a commit to mszhanyi/vision that referenced this pull request Oct 19, 2021
* Fixing minor issue on typing.

* Sample implementation for quantized resnet50.
facebook-github-bot pushed a commit that referenced this pull request Oct 19, 2021
Summary:
* Fixing minor issue on typing.

* Sample implementation for quantized resnet50.

Reviewed By: NicolasHug

Differential Revision: D31758317

fbshipit-source-id: 087671ad10716bc23e6aefccaf7c7e03cafc190e
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
* Fixing minor issue on typing.

* Sample implementation for quantized resnet50.
Successfully merging this pull request may close these issues.

Multi-pretrained weights: Add initial API and basic implementation
5 participants