-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add spatial rescaler #414
Merged
Merged
add spatial rescaler #414
Changes from 3 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
dd6d82e
add spatial rescaler
guopengf 66fa538
fix import path
guopengf 1a00e9a
add unittests
guopengf d0b129c
fix test name
guopengf ce72651
support size argument
guopengf 01ff8dc
fix docstring and add more test cases for multiplier
guopengf b6ac768
fix error message
guopengf File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from functools import partial | ||
|
||
import torch | ||
import torch.nn as nn | ||
from monai.networks.blocks import Convolution | ||
|
||
__all__ = ["SpatialRescaler"] | ||
|
||
|
||
class SpatialRescaler(nn.Module): | ||
""" | ||
SpatialRescaler based on https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/encoders/modules.py | ||
|
||
Args: | ||
n_stages: number of interpolation stages. | ||
method: algorithm used for sampling. | ||
multiplier: multiplier for spatial size. If scale_factor is a tuple, | ||
its length has to match the number of spatial dimensions. | ||
in_channels: number of input channels. | ||
out_channels: number of output channels. | ||
bias: whether to have a bias term. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
spatial_dims: int = 2, | ||
n_stages: int = 1, | ||
method: str = "bilinear", | ||
multiplier: float = 0.5, | ||
marksgraham marked this conversation as resolved.
Show resolved
Hide resolved
|
||
in_channels: int = 3, | ||
out_channels: int = None, | ||
bias: bool = False, | ||
): | ||
super().__init__() | ||
self.n_stages = n_stages | ||
assert self.n_stages >= 0 | ||
assert method in ["nearest", "linear", "bilinear", "trilinear", "bicubic", "area"] | ||
self.multiplier = multiplier | ||
self.interpolator = partial(torch.nn.functional.interpolate, mode=method) | ||
self.remap_output = out_channels is not None | ||
if self.remap_output: | ||
print(f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels before resizing.") | ||
self.channel_mapper = Convolution( | ||
spatial_dims=spatial_dims, | ||
in_channels=in_channels, | ||
out_channels=out_channels, | ||
kernel_size=1, | ||
conv_only=True, | ||
bias=bias, | ||
) | ||
|
||
def forward(self, x): | ||
marksgraham marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if self.remap_output: | ||
x = self.channel_mapper(x) | ||
|
||
for stage in range(self.n_stages): | ||
x = self.interpolator(x, scale_factor=self.multiplier) | ||
|
||
return x | ||
|
||
def encode(self, x): | ||
marksgraham marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return self(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import torch | ||
from parameterized import parameterized | ||
|
||
from generative.networks.blocks import SpatialRescaler | ||
|
||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
CASES = [ | ||
[ | ||
{ | ||
"spatial_dims": 2, | ||
"n_stages": 1, | ||
"method": "bilinear", | ||
"multiplier": 0.5, | ||
"in_channels": None, | ||
"out_channels": None, | ||
}, | ||
(1, 1, 16, 16), | ||
(1, 1, 8, 8), | ||
], | ||
[ | ||
{ | ||
"spatial_dims": 2, | ||
"n_stages": 1, | ||
"method": "bilinear", | ||
"multiplier": 0.5, | ||
"in_channels": 3, | ||
"out_channels": 2, | ||
}, | ||
(1, 3, 16, 16), | ||
(1, 2, 8, 8), | ||
], | ||
[ | ||
{ | ||
"spatial_dims": 3, | ||
"n_stages": 1, | ||
"method": "trilinear", | ||
"multiplier": 0.5, | ||
"in_channels": None, | ||
"out_channels": None, | ||
}, | ||
(1, 1, 16, 16, 16), | ||
(1, 1, 8, 8, 8), | ||
], | ||
[ | ||
{ | ||
"spatial_dims": 3, | ||
"n_stages": 1, | ||
"method": "trilinear", | ||
"multiplier": 0.5, | ||
"in_channels": 3, | ||
"out_channels": 2, | ||
}, | ||
(1, 3, 16, 16, 16), | ||
(1, 2, 8, 8, 8), | ||
], | ||
] | ||
|
||
|
||
class TestAutoEncoderKL(unittest.TestCase): | ||
marksgraham marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@parameterized.expand(CASES) | ||
def test_shape(self, input_param, input_shape, expected_shape): | ||
module = SpatialRescaler(**input_param).to(device) | ||
|
||
result = module(torch.randn(input_shape).to(device)) | ||
self.assertEqual(result.shape, expected_shape) | ||
|
||
def test_method_not_in_available_options(self): | ||
with self.assertRaises(AssertionError): | ||
SpatialRescaler(method="none") | ||
|
||
def test_n_stages_is_negative(self): | ||
with self.assertRaises(AssertionError): | ||
SpatialRescaler(n_stages=-1) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this part about
scale_factor
doesn't make sense - there is no variable with that name, and multiplier can't be a tuple. maybe just delete the part aboutscale_factor
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing out the docstring problem. I have fixed it. This remembers me that the
multiplier
can be aSequence[float]
. I already added this type and a new test case for this situation in the latest commit.GenerativeModels/tests/test_encoder_modules.py
Line 77 in 01ff8dc