Skip to content

Commit

Permalink
fix docstring and add more test cases for multiplier
Browse files Browse the repository at this point in the history
  • Loading branch information
guopengf committed Aug 11, 2023
1 parent ce72651 commit 01ff8dc
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
6 changes: 3 additions & 3 deletions generative/networks/blocks/encoder_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class SpatialRescaler(nn.Module):
n_stages: number of interpolation stages.
size: output spatial size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]).
method: algorithm used for sampling.
multiplier: multiplier for spatial size. If scale_factor is a tuple,
its length has to match the number of spatial dimensions.
multiplier: multiplier for spatial size. If `multiplier` is a sequence,
its length has to match the number of spatial dimensions; `input.dim() - 2`.
in_channels: number of input channels.
out_channels: number of output channels.
bias: whether to have a bias term.
Expand All @@ -43,7 +43,7 @@ def __init__(
n_stages: int = 1,
size: Sequence[int] | int | None = None,
method: str = "bilinear",
multiplier: float | None = None,
multiplier: Sequence[float] | float | None = None,
in_channels: int = 3,
out_channels: int = None,
bias: bool = False,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_encoder_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@
(1, 3, 16, 16, 16),
(1, 2, 8, 8, 8),
],
[
{
"spatial_dims": 3,
"n_stages": 1,
"method": "trilinear",
"multiplier": (0.25, 0.5, 0.75),
"in_channels": 3,
"out_channels": 2,
},
(1, 3, 20, 20, 20),
(1, 2, 5, 10, 15),
],
[
{"spatial_dims": 2, "n_stages": 1, "size": (8, 8), "method": "bilinear", "in_channels": 3, "out_channels": 2},
(1, 3, 16, 16),
Expand Down

0 comments on commit 01ff8dc

Please sign in to comment.