Merge pull request #305 from Project-MONAI/304-move-preparebatch
Moving DiffusionPrepareBatch
ericspod authored Mar 13, 2023
2 parents 9e95e54 + ff5497a commit 0477a22
Showing 15 changed files with 108 additions and 172 deletions.
1 change: 1 addition & 0 deletions generative/engines/__init__.py
@@ -11,4 +11,5 @@

from __future__ import annotations

from .prepare_batch import DiffusionPrepareBatch, VPredictionPrepareBatch
from .trainer import AdversarialTrainer
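
Since the package `__init__` now re-exports the new classes, downstream code can import them directly from `generative.engines`. A minimal check of the public path (usage sketch, not part of this diff):

from generative.engines import AdversarialTrainer, DiffusionPrepareBatch, VPredictionPrepareBatch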
92 changes: 92 additions & 0 deletions generative/engines/prepare_batch.py
@@ -0,0 +1,92 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Dict, Mapping, Optional, Union

import torch
import torch.nn as nn
from monai.engines import PrepareBatch, default_prepare_batch


class DiffusionPrepareBatch(PrepareBatch):
"""
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise".
This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
"""

def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None) -> None:
self.condition_name = condition_name
self.num_train_timesteps = num_train_timesteps

def get_noise(self, images: torch.Tensor) -> torch.Tensor:
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
return torch.randn_like(images)

def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
"""Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`."""
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()

def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
"""Return the target for the loss function, this is the `noise` value by default."""
return noise

def __call__(
self,
batchdata: Dict[str, torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
):
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)

target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs)
infer_kwargs = {"noise": noise, "timesteps": timesteps}

if self.condition_name is not None and isinstance(batchdata, Mapping):
infer_kwargs["conditioning"] = batchdata[self.condition_name].to(
device, non_blocking=non_blocking, **kwargs
)

# return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
return images, target, (), infer_kwargs


class VPredictionPrepareBatch(DiffusionPrepareBatch):
"""
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
from this compute the velocity using the provided scheduler. This value is used as the target in place of the
noise field itself although the noise is field is in the kwargs under the key "noise". This assumes the inferer
being used in conjunction with this class expects a "noise" parameter to be provided.
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
"""

def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: Optional[str] = None) -> None:
super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)
self.scheduler = scheduler

def get_target(self, images, noise, timesteps):
return self.scheduler.get_velocity(images, noise, timesteps)
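
For orientation, the sketch below shows how these classes are intended to be wired into a MONAI workflow. It is not part of this commit; the network and scheduler hyperparameters, the dummy data, and the epoch count are illustrative placeholders rather than the bundle's settings.

import torch
from monai.data import DataLoader, Dataset
from monai.engines import SupervisedTrainer

from generative.engines import DiffusionPrepareBatch
from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet
from generative.networks.schedulers import DDPMScheduler

num_train_timesteps = 1000
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# dummy data standing in for a real image dataset of {"image": ...} dictionaries
train_ds = Dataset([{"image": torch.rand(1, 64, 64)} for _ in range(8)])
train_loader = DataLoader(train_ds, batch_size=4)

# placeholder network/scheduler configuration
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(64, 128, 128),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    num_head_channels=64,
).to(device)
scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
inferer = DiffusionInferer(scheduler)

trainer = SupervisedTrainer(
    device=device,
    max_epochs=1,
    train_data_loader=train_loader,
    network=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=2.5e-5),
    loss_function=torch.nn.MSELoss(),
    inferer=inferer,
    # supplies (image, target) pairs plus "noise"/"timesteps" keyword arguments for the inferer
    prepare_batch=DiffusionPrepareBatch(num_train_timesteps=num_train_timesteps),
)
trainer.run()

For v-prediction training, VPredictionPrepareBatch(scheduler=scheduler, num_train_timesteps=num_train_timesteps) would be substituted, in which case the loss target is the scheduler's velocity (conventionally sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0) rather than the raw noise, assuming the scheduler implements get_velocity.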
1 change: 0 additions & 1 deletion generative/networks/nets/patchgan_discriminator.py
@@ -154,7 +154,6 @@ def __init__(
        dropout: float | tuple = 0.0,
        last_conv_kernel_size: int | None = None,
    ) -> None:

        super().__init__()
        self.num_layers_d = num_layers_d
        self.num_channels = num_channels
2 changes: 1 addition & 1 deletion model-zoo/models/mednist_ddpm/bundle/configs/train.yaml
@@ -91,7 +91,7 @@ optimizer:
  lr: '@lr'

prepare_batch:
  _target_: scripts.DiffusionPrepareBatch
  _target_: generative.engines.DiffusionPrepareBatch
  num_train_timesteps: '@num_train_timesteps'

val_handlers:
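
As context for this one-line change: the bundle's `_target_` entries are instantiated by MONAI's config machinery, so the edit simply repoints the reference from the bundle-local `scripts` module to the library class. A minimal sketch of how the section above would be resolved, assuming the yaml is read with monai.bundle.ConfigParser (the config path is illustrative, and '@num_train_timesteps' is defined elsewhere in the same file):

from monai.bundle import ConfigParser

parser = ConfigParser()
parser.read_config("model-zoo/models/mednist_ddpm/bundle/configs/train.yaml")

# builds generative.engines.DiffusionPrepareBatch(num_train_timesteps=...) from the `_target_` entry,
# resolving the '@num_train_timesteps' reference against the rest of the config
prepare_batch = parser.get_parsed_content("prepare_batch")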
49 changes: 0 additions & 49 deletions model-zoo/models/mednist_ddpm/bundle/scripts/__init__.py
@@ -1,54 +1,5 @@
from __future__ import annotations

from typing import Dict, Mapping, Optional, Union

import torch
from monai.engines import PrepareBatch, default_prepare_batch


class DiffusionPrepareBatch(PrepareBatch):
"""
This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
return the image and noise field as the image/target pair plus the noise field the kwargs under the key "noise".
This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.
If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
"""

def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None:
self.condition_name = condition_name
self.num_train_timesteps = num_train_timesteps

def get_noise(self, images: torch.Tensor) -> torch.Tensor:
"""Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
return torch.randn_like(images)

def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()

def __call__(
self,
batchdata: Dict[str, torch.Tensor],
device: Union[str, torch.device] | None = None,
non_blocking: bool = False,
**kwargs,
):
images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)

kwargs = {"noise": noise, "timesteps": timesteps}

if self.condition_name is not None and isinstance(batchdata, Mapping):
kwargs["conditioning"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)

# return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
return images, noise, (), kwargs


def inv_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:
"""
1 change: 0 additions & 1 deletion tests/min_tests.py
@@ -54,7 +54,6 @@ def run_testsuit():


if __name__ == "__main__":

    # testing import submodules
    from monai.utils.module import load_submodules

1 change: 0 additions & 1 deletion tests/runner.py
@@ -114,7 +114,6 @@ def get_default_pattern(loader):


if __name__ == "__main__":

    # Parse input arguments
    args = parse_args()

1 change: 0 additions & 1 deletion tests/test_diffusion_inferer.py
@@ -53,7 +53,6 @@
class TestDiffusionSamplingInferer(unittest.TestCase):
    @parameterized.expand(TEST_CASES)
    def test_call(self, model_params, input_shape):

        model = DiffusionModelUNet(**model_params)
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        model.to(device)
1 change: 0 additions & 1 deletion tests/test_patch_gan.py
@@ -94,7 +94,6 @@ def test_too_small_shape(self):
MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0])

    def test_script(self):

        net = MultiScalePatchDiscriminator(
            num_d=2,
            num_layers_d=3,
1 change: 0 additions & 1 deletion tests/utils.py
@@ -576,7 +576,6 @@ def run_process(func, args, kwargs, results):
results.put(e)

    def __call__(self, obj):

        if self.skip_timing:
            return obj

74 changes: 12 additions & 62 deletions tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb
@@ -44,6 +44,7 @@
"execution_count": 2,
"id": "dd62a552",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -112,6 +113,7 @@
"from monai.utils import first, set_determinism\n",
"\n",
"from generative.inferers import DiffusionInferer\n",
"from generative.engines import DiffusionPrepareBatch\n",
"\n",
"# TODO: Add right import reference after deployed\n",
"from generative.networks.nets import DiffusionModelUNet\n",
@@ -139,6 +141,7 @@
"execution_count": 3,
"id": "8fc58c80",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -171,6 +174,7 @@
"execution_count": 4,
"id": "ad5a1948",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -196,6 +200,7 @@
"execution_count": 5,
"id": "65e1c200",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -234,6 +239,7 @@
"execution_count": 6,
"id": "e2f9bebd",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -273,6 +279,7 @@
"execution_count": 7,
"id": "938318c2",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -322,6 +329,7 @@
"execution_count": 8,
"id": "b698f4f8",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -374,6 +382,7 @@
"execution_count": 9,
"id": "2c52e4f4",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
@@ -402,67 +411,6 @@
"inferer = DiffusionInferer(scheduler)"
]
},
{
"cell_type": "markdown",
"id": "655fa0a2-91f7-45e6-b3f8-259b76fe7e74",
"metadata": {},
"source": [
"### Define a class for preparing batches"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "15e46af7-c3e9-409b-ab1f-5884ada2729f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"class DiffusionPrepareBatch(PrepareBatch):\n",
" \"\"\"\n",
" This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.\n",
"\n",
" Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and\n",
" return the image and noise field as the image/target pair plus the noise field the kwargs under the key \"noise\".\n",
" This assumes the inferer being used in conjunction with this class expects a \"noise\" parameter to be provided.\n",
"\n",
" If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition\n",
" field to be passed to the inferer. This will appear in the keyword arguments under the key \"condition\".\n",
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, num_train_timesteps: int, condition_name: Optional[str] = None):\n",
" self.condition_name = condition_name\n",
" self.num_train_timesteps = num_train_timesteps\n",
"\n",
" def get_noise(self, images):\n",
" \"\"\"Returns the noise tensor for input tensor `images`, override this for different noise distributions.\"\"\"\n",
" return torch.randn_like(images)\n",
"\n",
" def get_timesteps(self, images):\n",
" return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()\n",
"\n",
" def __call__(\n",
" self,\n",
" batchdata: Dict[str, torch.Tensor],\n",
" device: Optional[Union[str, torch.device]] = None,\n",
" non_blocking: bool = False,\n",
" **kwargs,\n",
" ):\n",
" images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)\n",
" noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)\n",
" timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)\n",
"\n",
" kwargs = {\"noise\": noise, \"timesteps\": timesteps}\n",
"\n",
" if self.condition_name is not None and isinstance(batchdata, Mapping):\n",
" kwargs[\"conditioning\"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)\n",
"\n",
" # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value\n",
" return images, noise, (), kwargs"
]
},
{
"cell_type": "markdown",
"id": "5a316067",
@@ -477,6 +425,7 @@
"execution_count": 11,
"id": "0f697a13",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
@@ -2207,6 +2156,7 @@
"execution_count": 12,
"id": "1427e5d4",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
@@ -2291,7 +2241,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.13"
}
},
"nbformat": 4,