Skip to content

Commit

Permalink
Update pretrained diffusion model (#233)
Browse files Browse the repository at this point in the history
* use find_unused_parameters=True now necessary for training the DDPM with DDP

* Download updated checkpoint
  • Loading branch information
marksgraham authored Feb 7, 2023
1 parent 69118fa commit eeeca8f
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 13 deletions.
4 changes: 2 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@
"use_pretrained = False\n",
"\n",
"if use_pretrained:\n",
" model = torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n",
" model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n",
"else:\n",
" n_epochs = 100\n",
" val_interval = 10\n",
Expand Down Expand Up @@ -1096,7 +1096,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.4
# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand Down Expand Up @@ -207,7 +207,7 @@
use_pretrained = False

if use_pretrained:
model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device)
else:
n_epochs = 100
val_interval = 10
Expand Down
4 changes: 2 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_inpainting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@
"use_pretrained = False\n",
"\n",
"if use_pretrained:\n",
" model = torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n",
" model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n",
"else:\n",
" n_epochs = 50\n",
" val_interval = 5\n",
Expand Down Expand Up @@ -914,7 +914,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.4
# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand Down Expand Up @@ -191,7 +191,7 @@
use_pretrained = False

if use_pretrained:
model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device)
else:
n_epochs = 50
val_interval = 5
Expand Down
15 changes: 13 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"execution_count": 2,
"id": "dd62a552",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -137,6 +138,7 @@
"execution_count": 3,
"id": "8fc58c80",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -169,6 +171,7 @@
"execution_count": 4,
"id": "ad5a1948",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand All @@ -194,6 +197,7 @@
"execution_count": 5,
"id": "65e1c200",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -232,6 +236,7 @@
"execution_count": 6,
"id": "e2f9bebd",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -271,6 +276,7 @@
"execution_count": 7,
"id": "938318c2",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -320,6 +326,7 @@
"execution_count": 8,
"id": "b698f4f8",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -372,6 +379,7 @@
"execution_count": 9,
"id": "2c52e4f4",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
Expand Down Expand Up @@ -415,6 +423,7 @@
"execution_count": 10,
"id": "0f697a13",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
Expand Down Expand Up @@ -763,7 +772,7 @@
"use_pretrained = False\n",
"\n",
"if use_pretrained:\n",
" model = torch.hub.load(\"marksgraham/pretrained_generative_models\", model=\"ddpm_2d\", verbose=True).to(device)\n",
" model = torch.hub.load(\"marksgraham/pretrained_generative_models:v0.2\", model=\"ddpm_2d\", verbose=True).to(device)\n",
"else:\n",
" n_epochs = 75\n",
" val_interval = 5\n",
Expand Down Expand Up @@ -852,6 +861,7 @@
"execution_count": 11,
"id": "2cdcda81",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -901,6 +911,7 @@
"execution_count": 12,
"id": "1427e5d4",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
Expand Down Expand Up @@ -984,7 +995,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.4
# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand Down Expand Up @@ -190,7 +190,7 @@
use_pretrained = False

if use_pretrained:
model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
model = torch.hub.load("marksgraham/pretrained_generative_models:v0.2", model="ddpm_2d", verbose=True).to(device)
else:
n_epochs = 75
val_interval = 5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def main_worker(args):

inferer = DiffusionInferer(scheduler)
# wrap the model with DistributedDataParallel module
model = DistributedDataParallel(model, device_ids=[device])
model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

# start a typical PyTorch training
best_metric = 10000
Expand Down

0 comments on commit eeeca8f

Please sign in to comment.