Skip to content

Commit

Permalink
add spatial rescaler (#414)
Browse files Browse the repository at this point in the history
* add spatial rescaler

* fix import path

* add unittests

* fix test name

* support size argument

* fix docstring and add more test cases for multiplier

* fix error message
  • Loading branch information
guopengf authored Aug 14, 2023
1 parent 4547ca4 commit 11e6323
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
1 change: 1 addition & 0 deletions generative/networks/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@

from __future__ import annotations

from .encoder_modules import SpatialRescaler
from .selfattention import SABlock
from .transformerblock import TransformerBlock
83 changes: 83 additions & 0 deletions generative/networks/blocks/encoder_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from functools import partial

import torch
import torch.nn as nn
from monai.networks.blocks import Convolution

__all__ = ["SpatialRescaler"]


class SpatialRescaler(nn.Module):
    """
    Resize the spatial dimensions of a tensor by interpolation, optionally remapping its channels first.

    Based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py

    When `out_channels` is given, a 1x1 convolution maps `in_channels` to `out_channels` before any
    resizing. The interpolation is then applied `n_stages` times, each time with the same `size` or
    `multiplier`.

    Args:
        spatial_dims: number of spatial dimensions.
        n_stages: number of interpolation stages. With `0` no interpolation is applied.
        size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]).
            Cannot be combined with `multiplier` and requires `n_stages` to be 1.
        method: algorithm used for sampling. One of "nearest", "linear", "bilinear", "trilinear",
            "bicubic" or "area".
        multiplier: multiplier for spatial size, applied at every stage. If `multiplier` is a sequence,
            its length has to match the number of spatial dimensions; `input.dim() - 2`.
        in_channels: number of input channels (only used when `out_channels` is given).
        out_channels: number of output channels. If `None`, the channels are left untouched.
        bias: whether the channel-mapping convolution has a bias term.

    Raises:
        AssertionError: if `n_stages` is negative or `method` is not one of the supported modes.
        ValueError: if `size` is given together with `n_stages != 1` or together with `multiplier`.
    """

    def __init__(
        self,
        spatial_dims: int = 2,
        n_stages: int = 1,
        size: Sequence[int] | int | None = None,
        method: str = "bilinear",
        multiplier: Sequence[float] | float | None = None,
        in_channels: int = 3,
        out_channels: int | None = None,
        bias: bool = False,
    ):
        super().__init__()
        self.n_stages = n_stages

        # These checks are raised explicitly (instead of using `assert`) so that they are not
        # stripped when Python runs with `-O`. The AssertionError type is kept on purpose so that
        # existing callers and tests that catch it keep working.
        if n_stages < 0:
            raise AssertionError(f"n_stages should be non-negative, got {n_stages}.")
        supported_methods = ("nearest", "linear", "bilinear", "trilinear", "bicubic", "area")
        if method not in supported_methods:
            raise AssertionError(f"method should be one of {supported_methods}, got {method!r}.")

        if size is not None and n_stages != 1:
            raise ValueError("when size is not None, n_stages should be 1.")
        if size is not None and multiplier is not None:
            raise ValueError("only one of size or multiplier should be defined.")

        self.multiplier = multiplier
        # `mode` and `size` are fixed here; the scale factor is supplied at call time in `forward`.
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method, size=size)

        self.remap_output = out_channels is not None
        if self.remap_output:
            print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.")
            # 1x1 convolution without norm/activation: a pure per-position channel projection.
            self.channel_mapper = Convolution(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                conv_only=True,
                bias=bias,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Optionally remap the channels of `x`, then interpolate it `n_stages` times."""
        if self.remap_output:
            x = self.channel_mapper(x)

        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        return x

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Alias for calling the module on `x`."""
        return self(x)
130 changes: 130 additions & 0 deletions tests/test_encoder_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from generative.networks.blocks import SpatialRescaler

# Run the tests on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Each case is [constructor kwargs, input shape, expected output shape].
CASES = [
    # 2D, bilinear, halve every spatial dimension, no channel remapping.
    [
        dict(spatial_dims=2, n_stages=1, method="bilinear", multiplier=0.5, in_channels=None, out_channels=None),
        (1, 1, 16, 16),
        (1, 1, 8, 8),
    ],
    # 2D, bilinear, halve every spatial dimension, remap 3 -> 2 channels.
    [
        dict(spatial_dims=2, n_stages=1, method="bilinear", multiplier=0.5, in_channels=3, out_channels=2),
        (1, 3, 16, 16),
        (1, 2, 8, 8),
    ],
    # 3D, trilinear, halve every spatial dimension, no channel remapping.
    [
        dict(spatial_dims=3, n_stages=1, method="trilinear", multiplier=0.5, in_channels=None, out_channels=None),
        (1, 1, 16, 16, 16),
        (1, 1, 8, 8, 8),
    ],
    # 3D, trilinear, halve every spatial dimension, remap 3 -> 2 channels.
    [
        dict(spatial_dims=3, n_stages=1, method="trilinear", multiplier=0.5, in_channels=3, out_channels=2),
        (1, 3, 16, 16, 16),
        (1, 2, 8, 8, 8),
    ],
    # 3D, trilinear, a different multiplier per spatial dimension.
    [
        dict(
            spatial_dims=3,
            n_stages=1,
            method="trilinear",
            multiplier=(0.25, 0.5, 0.75),
            in_channels=3,
            out_channels=2,
        ),
        (1, 3, 20, 20, 20),
        (1, 2, 5, 10, 15),
    ],
    # 2D, explicit output size instead of a multiplier, remap 3 -> 2 channels.
    [
        dict(spatial_dims=2, n_stages=1, size=(8, 8), method="bilinear", in_channels=3, out_channels=2),
        (1, 3, 16, 16),
        (1, 2, 8, 8),
    ],
    # 3D, explicit output size instead of a multiplier, no channel remapping.
    [
        dict(spatial_dims=3, n_stages=1, size=(8, 8, 8), method="trilinear", in_channels=None, out_channels=None),
        (1, 1, 16, 16, 16),
        (1, 1, 8, 8, 8),
    ],
]


class TestSpatialRescaler(unittest.TestCase):
    """Output-shape and argument-validation tests for `SpatialRescaler`."""

    @parameterized.expand(CASES)
    def test_shape(self, input_param, input_shape, expected_shape):
        # Build the module from the case's kwargs and move it to the test device.
        rescaler = SpatialRescaler(**input_param)
        rescaler = rescaler.to(device)
        inputs = torch.randn(input_shape).to(device)

        outputs = rescaler(inputs)

        self.assertEqual(outputs.shape, expected_shape)

    def test_method_not_in_available_options(self):
        # An unknown interpolation mode is rejected at construction time.
        self.assertRaises(AssertionError, SpatialRescaler, method="none")

    def test_n_stages_is_negative(self):
        # A negative number of stages is rejected at construction time.
        self.assertRaises(AssertionError, SpatialRescaler, n_stages=-1)

    def test_use_size_but_n_stages_is_not_one(self):
        # `size` may only be used with a single interpolation stage.
        self.assertRaises(ValueError, SpatialRescaler, n_stages=2, size=[8, 8, 8])

    def test_both_size_and_multiplier_defined(self):
        # `size` and `multiplier` are mutually exclusive.
        self.assertRaises(ValueError, SpatialRescaler, size=[1, 2, 3], multiplier=0.5)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 11e6323

Please sign in to comment.