Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-pretrained weight support - initial API + ResNet50 #4610

Merged
merged 10 commits into from
Oct 14, 2021

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Oct 13, 2021

Resolves #4671

Adds multi-pretrained weight support on the existing model builders of TorchVision.

Example usage:

from PIL import Image

from torchvision import prototype as P


img = Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Initialize model
weights = P.models.ResNet50Weights.ImageNet1K_RefV2
model = P.models.resnet50(weights=weights)
model.eval()

# Initialize inference transforms
preprocess = weights.transforms()

# Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)
prediction = model(batch).squeeze(0).softmax(0)

# Make predictions
label = prediction.argmax().item()
score = prediction[label].item()

# Use meta to get label
category_name = weights.meta['categories'][label]
print(f"{category_name}: {100 * score}%")

cc @datumbox @pmeier @bjuncek

@datumbox datumbox requested a review from fmassa October 13, 2021 13:46
@datumbox datumbox changed the title Prototype models - ResNet50 [WIP] Prototype models - ResNet50 Oct 13, 2021
Copy link
Member

@fmassa fmassa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to go!

meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/issues/3995",
"acc@1": 80.352,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got more still in the oven 🤞

},
)
ImageNet1K_RefV2 = WeightEntry(
url="https://download.pytorch.org/models/resnet50-tmp.pth",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the plan here, to re-upload at some point in the future? Also, how do we plan on keeping the names for the checkpoint files manageable, just rely on the sha256 to differentiate them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the plan here, to re-upload at some point in the future?

Yes indeed, I still got models being trained so I expect that the weights will change. Just wanted to add something here so that we can see how multiple weights work.

how do we plan on keeping the names for the checkpoint files manageable

Good point. This is why I didn't add the sha256 on the temporary model. I don't want to fill the bucket with mess. I expect there will be one final set of weights added here at the end of all training. Since we are on prototype, I consider I can change it at any time.

just rely on the sha256 to differentiate them?

I don't have a preference. We could introduce more descriptive names (perhaps using the same string as the enum name?) or just rely on sha256.

"acc@5": 92.862,
},
)
ImageNet1K_RefV2 = WeightEntry(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming will be important here. ImageNet1K_RefV2 sounds good for a v1, but we should have a webpage in the doc which will break this down nicely. Maybe something to keep in mind, an easy way to gather this information automatically to facilitate generating the documentation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you are right. The plan is to change this once the recipe is finalized. We will need to update this, along with the URL of the recipe (currently pointing to the issue that I got open).


_common_meta = {
"size": (224, 224),
"categories": list(range(1000)), # TODO: torchvision.prototype.datasets.find("ImageNet").info.categories
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier Let me know when you got the ImageNet category class so that I can replace it here.

return F.convert_image_dtype(img, self.dtype)


class ImageNetEval:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pmeier As discussed offline with @fmassa, feel free to modify and take ownership of this component when you land your PR related to the transforms. As long as we get BC results, we can do any change necessary.

@datumbox datumbox force-pushed the prototype/models_resnet50 branch from e9ff413 to 104a05c Compare October 13, 2021 19:31
@datumbox datumbox changed the title [WIP] Prototype models - ResNet50 Prototype models - ResNet50 Oct 14, 2021
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just questions from me :)

torchvision/prototype/models/resnet.py Show resolved Hide resolved
Comment on lines +62 to +63
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a future plan to allow users to get a pretrained model, without needing to manually instanciate a ResNet50Weights weights object? E.g. something like resnet50(weights='pretrained') would always produce the "default pretrained weights" (which could be e.g. the latest version of the weights, or something else)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To give more context to my question: this isn't just about convenience, but also regarding torchhub.

It'd be cool if we could still load models from torchhub using just torch.load('pytorch/vision', 'resnet50', pretrained=SOMETHING)

where SOMETHING doesn't have to be a custom torchvision class like ResNet50Weights

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, definitely a feature we want to add. See here for a prototype of exactly what you said. I choose not to include it in this prototype to go for the absolute minimal implementation and give time to review the other RFC as a whole.

Concerning your comment to not require access to the Enum object, I think you are hinting TorchHub here. For this use-case if you pass the string name of the enum value it will build it for you. See this. Could you have one more look and let me know if this works for starters?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we posted a reply at the same minute (again!). You are currently able to instantiate a model as follows as well:

model = P.models.resnet50(weights="ImageNet1K_RefV2")

Thus I believe on torchhub you will do:

torch.load('pytorch/vision', 'resnet50', weights="ImageNet1K_RefV2")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that makes sense

@datumbox datumbox changed the title Prototype models - ResNet50 Multi-pretrained weight support - ResNet50 Oct 14, 2021
@datumbox datumbox changed the title Multi-pretrained weight support - ResNet50 Multi-pretrained weight support - initial API + ResNet50 Oct 14, 2021
@datumbox datumbox merged commit 8fe72d1 into pytorch:main Oct 14, 2021
@datumbox datumbox deleted the prototype/models_resnet50 branch October 14, 2021 09:49
facebook-github-bot pushed a commit that referenced this pull request Oct 14, 2021
)

Summary:
* Adding lightweight API for models.

* Adding resnet50.

* Fix preset

* Add fake categories.

* Fixing mypy.

* Add string=>weight conversion support on Enums.

* Temporarily hardcoding imagenet categories.

* Minor refactoring.

Reviewed By: fmassa

Differential Revision: D31649970

fbshipit-source-id: b4908da7be972c0a19949e75d61f2051e785494c
mszhanyi pushed a commit to mszhanyi/vision that referenced this pull request Oct 19, 2021
* Adding lightweight API for models.

* Adding resnet50.

* Fix preset

* Add fake categories.

* Fixing mypy.

* Add string=>weight conversion support on Enums.

* Temporarily hardcoding imagenet categories.

* Minor refactoring.
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
* Adding lightweight API for models.

* Adding resnet50.

* Fix preset

* Add fake categories.

* Fixing mypy.

* Add string=>weight conversion support on Enums.

* Temporarily hardcoding imagenet categories.

* Minor refactoring.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Multi-pretrained weights: Add initial API and basic implementation
4 participants