Skip to content

Commit

Permalink
Update pretrained_weights.ipynb
Browse files Browse the repository at this point in the history
Fixed an error in the state dict loading of the turorial and added a comment on the num_classes parameter when creating timm models.
  • Loading branch information
konstantinklemmer authored and calebrob6 committed Dec 2, 2023
1 parent e8e99f0 commit 1e19d7d
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions docs/tutorials/pretrained_weights.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,26 @@
"outputs": [],
"source": [
"in_chans = weights.meta[\"in_chans\"]\n",
"model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=10)\n",
"model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
"model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=0)\n",
"model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Setting `num_classes=0` will prevent the creation of a prediction head (fully-connected layer). However, you can create a timm model with a prediction head and still match the keys of all but the last fully-connected layer, the parameters of which will be randomly initialized, when using the `strict=False` flag."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"in_chans = weights.meta[\"in_chans\"]\n",
"model = timm.create_model(\"resnet18\", in_chans=in_chans, num_classes=256)\n",
"model.load_state_dict(weights.get_state_dict(progress=True), strict=False)"
]
},
{
Expand Down

0 comments on commit 1e19d7d

Please sign in to comment.