From dd6d82ea353550b80b2ccacffaf9817217f5f55a Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 27 Jul 2023 16:15:46 -0400 Subject: [PATCH 1/7] add spatial rescaler --- generative/networks/blocks/__init__.py | 1 + generative/networks/blocks/encoder_modules.py | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 generative/networks/blocks/encoder_modules.py diff --git a/generative/networks/blocks/__init__.py b/generative/networks/blocks/__init__.py index 49c4a32c..72e86904 100644 --- a/generative/networks/blocks/__init__.py +++ b/generative/networks/blocks/__init__.py @@ -11,5 +11,6 @@ from __future__ import annotations +from encoder_modules import SpatialRescaler from .selfattention import SABlock from .transformerblock import TransformerBlock diff --git a/generative/networks/blocks/encoder_modules.py b/generative/networks/blocks/encoder_modules.py new file mode 100644 index 00000000..790fed3a --- /dev/null +++ b/generative/networks/blocks/encoder_modules.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from functools import partial + +import torch +import torch.nn as nn +from monai.networks.blocks import Convolution + +__all__ = ["SpatialRescaler"] + + +class SpatialRescaler(nn.Module): + """ + SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py + + Args: + n_stages: number of interpolation stages. + method: algorithm used for sampling. + multiplier: multiplier for spatial size. If scale_factor is a tuple, + its length has to match the number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + bias: whether to have a bias term. + """ + + def __init__(self, + spatial_dims: int = 2, + n_stages: int = 1, + method: str = 'bilinear', + multiplier: float = 0.5, + in_channels: int = 3, + out_channels: int = None, + bias: bool = False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + conv_only=True, + bias=bias + ) + + def forward(self, x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) From 66fa5385107278bdf2cfdee5e1089348a43d7e58 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Thu, 27 Jul 2023 16:17:45 -0400 Subject: [PATCH 2/7] fix import path --- generative/networks/blocks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/networks/blocks/__init__.py b/generative/networks/blocks/__init__.py index 72e86904..b7931237 100644 --- a/generative/networks/blocks/__init__.py +++ b/generative/networks/blocks/__init__.py @@ -11,6 +11,6 @@ from __future__ import annotations -from encoder_modules import SpatialRescaler +from .encoder_modules import SpatialRescaler from .selfattention import SABlock from .transformerblock import TransformerBlock From 1a00e9af383b5e12038a8c4097db17be4bfd1107 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Tue, 1 Aug 2023 17:56:05 -0400 Subject: [PATCH 3/7] add unittests --- generative/networks/blocks/encoder_modules.py | 47 +++++----- tests/test_encoder_modules.py | 93 +++++++++++++++++++ 2 files changed, 118 insertions(+), 22 deletions(-) create mode 100644 tests/test_encoder_modules.py diff --git a/generative/networks/blocks/encoder_modules.py b/generative/networks/blocks/encoder_modules.py index 790fed3a..8d0a1b83 100644 --- a/generative/networks/blocks/encoder_modules.py +++ b/generative/networks/blocks/encoder_modules.py @@ -22,50 +22,53 @@ class SpatialRescaler(nn.Module): """ - SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py + SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py - Args: - n_stages: number of interpolation stages. - method: algorithm used for sampling. - multiplier: multiplier for spatial size. If scale_factor is a tuple, - its length has to match the number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - bias: whether to have a bias term. + Args: + n_stages: number of interpolation stages. + method: algorithm used for sampling. + multiplier: multiplier for spatial size. If scale_factor is a tuple, + its length has to match the number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + bias: whether to have a bias term. """ - def __init__(self, - spatial_dims: int = 2, - n_stages: int = 1, - method: str = 'bilinear', - multiplier: float = 0.5, - in_channels: int = 3, - out_channels: int = None, - bias: bool = False): + def __init__( + self, + spatial_dims: int = 2, + n_stages: int = 1, + method: str = "bilinear", + multiplier: float = 0.5, + in_channels: int = 3, + out_channels: int = None, + bias: bool = False, + ): super().__init__() self.n_stages = n_stages assert self.n_stages >= 0 - assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'] + assert method in ["nearest", "linear", "bilinear", "trilinear", "bicubic", "area"] self.multiplier = multiplier self.interpolator = partial(torch.nn.functional.interpolate, mode=method) self.remap_output = out_channels is not None if self.remap_output: - print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.") self.channel_mapper = Convolution( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, kernel_size=1, conv_only=True, - bias=bias + bias=bias, ) def forward(self, x): + if self.remap_output: + x = self.channel_mapper(x) + for stage in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) - if self.remap_output: - x = self.channel_mapper(x) return x def encode(self, x): diff --git a/tests/test_encoder_modules.py b/tests/test_encoder_modules.py new file mode 100644 index 00000000..dbfd5455 --- /dev/null +++ b/tests/test_encoder_modules.py @@ -0,0 +1,93 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from generative.networks.blocks import SpatialRescaler + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +CASES = [ + [ + { + "spatial_dims": 2, + "n_stages": 1, + "method": "bilinear", + "multiplier": 0.5, + "in_channels": None, + "out_channels": None, + }, + (1, 1, 16, 16), + (1, 1, 8, 8), + ], + [ + { + "spatial_dims": 2, + "n_stages": 1, + "method": "bilinear", + "multiplier": 0.5, + "in_channels": 3, + "out_channels": 2, + }, + (1, 3, 16, 16), + (1, 2, 8, 8), + ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "method": "trilinear", + "multiplier": 0.5, + "in_channels": None, + "out_channels": None, + }, + (1, 1, 16, 16, 16), + (1, 1, 8, 8, 8), + ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "method": "trilinear", + "multiplier": 0.5, + "in_channels": 3, + "out_channels": 2, + }, + (1, 3, 16, 16, 16), + (1, 2, 8, 8, 8), + ], +] + + +class TestAutoEncoderKL(unittest.TestCase): + @parameterized.expand(CASES) + def test_shape(self, input_param, input_shape, expected_shape): + module = SpatialRescaler(**input_param).to(device) + + result = module(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape) + + def test_method_not_in_available_options(self): + with self.assertRaises(AssertionError): + SpatialRescaler(method="none") + + def test_n_stages_is_negative(self): + with self.assertRaises(AssertionError): + SpatialRescaler(n_stages=-1) + + +if __name__ == "__main__": + unittest.main() From d0b129c869eda761eda1b8af6305e9e96a849ac9 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Wed, 9 Aug 2023 22:19:36 -0400 Subject: [PATCH 4/7] fix test name --- tests/test_encoder_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_encoder_modules.py b/tests/test_encoder_modules.py index dbfd5455..83d4b0e3 100644 --- a/tests/test_encoder_modules.py +++ b/tests/test_encoder_modules.py @@ -72,7 +72,7 @@ ] -class TestAutoEncoderKL(unittest.TestCase): +class TestSpatialRescaler(unittest.TestCase): @parameterized.expand(CASES) def test_shape(self, input_param, input_shape, expected_shape): module = SpatialRescaler(**input_param).to(device) From ce72651fbb429b23c072fa868b91a0447bc13c99 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Wed, 9 Aug 2023 23:00:01 -0400 Subject: [PATCH 5/7] support size argument --- generative/networks/blocks/encoder_modules.py | 16 +++++++++--- tests/test_encoder_modules.py | 25 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/generative/networks/blocks/encoder_modules.py b/generative/networks/blocks/encoder_modules.py index 8d0a1b83..cf8fa8b3 100644 --- a/generative/networks/blocks/encoder_modules.py +++ b/generative/networks/blocks/encoder_modules.py @@ -11,6 +11,7 @@ from __future__ import annotations +from collections.abc import Sequence from functools import partial import torch @@ -25,7 +26,9 @@ class SpatialRescaler(nn.Module): SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py Args: + spatial_dims: number of spatial dimensions. n_stages: number of interpolation stages. + size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]). method: algorithm used for sampling. multiplier: multiplier for spatial size. If scale_factor is a tuple, its length has to match the number of spatial dimensions. @@ -38,8 +41,9 @@ def __init__( self, spatial_dims: int = 2, n_stages: int = 1, + size: Sequence[int] | int | None = None, method: str = "bilinear", - multiplier: float = 0.5, + multiplier: float | None = None, in_channels: int = 3, out_channels: int = None, bias: bool = False, @@ -48,8 +52,12 @@ def __init__( self.n_stages = n_stages assert self.n_stages >= 0 assert method in ["nearest", "linear", "bilinear", "trilinear", "bicubic", "area"] + if size is not None and n_stages != 1: + raise ValueError("when size is not None, n_stages should be 1.") + if size is not None and multiplier is not None: + raise ValueError("only one of size or scale_factor should be defined.") self.multiplier = multiplier - self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.interpolator = partial(torch.nn.functional.interpolate, mode=method, size=size) self.remap_output = out_channels is not None if self.remap_output: print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.") @@ -62,7 +70,7 @@ def __init__( bias=bias, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.remap_output: x = self.channel_mapper(x) @@ -71,5 +79,5 @@ def forward(self, x): return x - def encode(self, x): + def encode(self, x: torch.Tensor) -> torch.Tensor: return self(x) diff --git a/tests/test_encoder_modules.py b/tests/test_encoder_modules.py index 83d4b0e3..74ac4703 100644 --- a/tests/test_encoder_modules.py +++ b/tests/test_encoder_modules.py @@ -69,6 +69,23 @@ (1, 3, 16, 16, 16), (1, 2, 8, 8, 8), ], + [ + {"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2}, + (1, 3, 16, 16), + (1, 2, 8, 8), + ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "size": (8, 8, 8), + "method": "trilinear", + "in_channels": None, + "out_channels": None, + }, + (1, 1, 16, 16, 16), + (1, 1, 8, 8, 8), + ], ] @@ -88,6 +105,14 @@ def test_n_stages_is_negative(self): with self.assertRaises(AssertionError): SpatialRescaler(n_stages=-1) + def test_use_size_but_n_stages_is_not_one(self): + with self.assertRaises(ValueError): + SpatialRescaler(n_stages=2, size=[8, 8, 8]) + + def test_both_size_and_multiplier_defined(self): + with self.assertRaises(ValueError): + SpatialRescaler(size=[1, 2, 3], multiplier=0.5) + if __name__ == "__main__": unittest.main() From 01ff8dc52582e16714edb475b990c80b82035d25 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Fri, 11 Aug 2023 11:31:13 -0400 Subject: [PATCH 6/7] fix docstring and add more test cases for multiplier --- generative/networks/blocks/encoder_modules.py | 6 +++--- tests/test_encoder_modules.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/generative/networks/blocks/encoder_modules.py b/generative/networks/blocks/encoder_modules.py index cf8fa8b3..ab42dad1 100644 --- a/generative/networks/blocks/encoder_modules.py +++ b/generative/networks/blocks/encoder_modules.py @@ -30,8 +30,8 @@ class SpatialRescaler(nn.Module): n_stages: number of interpolation stages. size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]). method: algorithm used for sampling. - multiplier: multiplier for spatial size. If scale_factor is a tuple, - its length has to match the number of spatial dimensions. + multiplier: multiplier for spatial size. If `multiplier` is a sequence, + its length has to match the number of spatial dimensions; `input.dim() - 2`. in_channels: number of input channels. out_channels: number of output channels. bias: whether to have a bias term. @@ -43,7 +43,7 @@ def __init__( n_stages: int = 1, size: Sequence[int] | int | None = None, method: str = "bilinear", - multiplier: float | None = None, + multiplier: Sequence[float] | float | None = None, in_channels: int = 3, out_channels: int = None, bias: bool = False, diff --git a/tests/test_encoder_modules.py b/tests/test_encoder_modules.py index 74ac4703..04639177 100644 --- a/tests/test_encoder_modules.py +++ b/tests/test_encoder_modules.py @@ -69,6 +69,18 @@ (1, 3, 16, 16, 16), (1, 2, 8, 8, 8), ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "method": "trilinear", + "multiplier": (0.25, 0.5, 0.75), + "in_channels": 3, + "out_channels": 2, + }, + (1, 3, 20, 20, 20), + (1, 2, 5, 10, 15), + ], [ {"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2}, (1, 3, 16, 16), From b6ac768ad41b044791d802952cc27cbf2464a94a Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Fri, 11 Aug 2023 11:37:46 -0400 Subject: [PATCH 7/7] fix error message --- generative/networks/blocks/encoder_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generative/networks/blocks/encoder_modules.py b/generative/networks/blocks/encoder_modules.py index ab42dad1..62eab739 100644 --- a/generative/networks/blocks/encoder_modules.py +++ b/generative/networks/blocks/encoder_modules.py @@ -55,7 +55,7 @@ def __init__( if size is not None and n_stages != 1: raise ValueError("when size is not None, n_stages should be 1.") if size is not None and multiplier is not None: - raise ValueError("only one of size or scale_factor should be defined.") + raise ValueError("only one of size or multiplier should be defined.") self.multiplier = multiplier self.interpolator = partial(torch.nn.functional.interpolate, mode=method, size=size) self.remap_output = out_channels is not None