From 01ff8dc52582e16714edb475b990c80b82035d25 Mon Sep 17 00:00:00 2001 From: Pengfei Guo Date: Fri, 11 Aug 2023 11:31:13 -0400 Subject: [PATCH] fix docstring and add more test cases for multiplier --- generative/networks/blocks/encoder_modules.py | 6 +++--- tests/test_encoder_modules.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/generative/networks/blocks/encoder_modules.py b/generative/networks/blocks/encoder_modules.py index cf8fa8b3..ab42dad1 100644 --- a/generative/networks/blocks/encoder_modules.py +++ b/generative/networks/blocks/encoder_modules.py @@ -30,8 +30,8 @@ class SpatialRescaler(nn.Module): n_stages: number of interpolation stages. size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]). method: algorithm used for sampling. - multiplier: multiplier for spatial size. If scale_factor is a tuple, - its length has to match the number of spatial dimensions. + multiplier: multiplier for spatial size. If `multiplier` is a sequence, + its length has to match the number of spatial dimensions; `input.dim() - 2`. in_channels: number of input channels. out_channels: number of output channels. bias: whether to have a bias term. @@ -43,7 +43,7 @@ def __init__( n_stages: int = 1, size: Sequence[int] | int | None = None, method: str = "bilinear", - multiplier: float | None = None, + multiplier: Sequence[float] | float | None = None, in_channels: int = 3, out_channels: int = None, bias: bool = False, diff --git a/tests/test_encoder_modules.py b/tests/test_encoder_modules.py index 74ac4703..04639177 100644 --- a/tests/test_encoder_modules.py +++ b/tests/test_encoder_modules.py @@ -69,6 +69,18 @@ (1, 3, 16, 16, 16), (1, 2, 8, 8, 8), ], + [ + { + "spatial_dims": 3, + "n_stages": 1, + "method": "trilinear", + "multiplier": (0.25, 0.5, 0.75), + "in_channels": 3, + "out_channels": 2, + }, + (1, 3, 20, 20, 20), + (1, 2, 5, 10, 15), + ], [ {"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2}, (1, 3, 16, 16),